mirror of https://github.com/llvm/torch-mlir
Remove acap_dispatch.
This is old code that barely worked, and this approach just won't scale. TorchFX seems like the tracing-based solution going forward.
parent 5f3b1ce0b8
commit 900f0e04aa
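Illustrative aside (not part of this commit): the TorchFX approach the message points to captures programs by symbolically tracing a module, rather than by intercepting kernels with dispatch keys as the deleted acap dispatcher did. A minimal sketch, assuming a PyTorch build where torch.fx is available:

import torch
import torch.fx

class Net(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0

# symbolic_trace records each torch-level call into a Graph without
# running per-op dispatcher interception.
traced = torch.fx.symbolic_trace(Net())
print(traced.graph)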
@@ -13,10 +13,7 @@ include_directories(BEFORE
link_directories("${TORCH_INSTALL_PREFIX}/lib")

add_library(TorchMLIRTorchPlugin SHARED
  builder/acap_dispatch.cpp
  builder/class_annotator.cpp
  builder/debug.cpp
  builder/func_builder.cpp
  builder/function_importer.cpp
  builder/module_builder.cpp
  builder/node_importer.cpp
@@ -1,595 +0,0 @@
//===- acap_dispatch.cpp --------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//

#include "acap_dispatch.h"
#include "debug.h"
#include "mlir_utils.h"

#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "torch-mlir-c/TorchTypes.h"

#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <c10/core/DispatchKey.h>
#include <torch/library.h>

using namespace torch_mlir;

namespace py = pybind11;

using c10::FunctionSchema;
using c10::IValue;
using c10::OperatorHandle;
using c10::Stack;

// TODO: Private use dispatch keys are not made for real uses. Allocate a proper
// dispatch key in upstream PyTorch (DispatchKey.h) prior to maturity. Note
// that the TORCH_LIBRARY_* macros expand this by name and other APIs use its
// enum value, so we define both. We can get rid of both once we have our
// own key.
// TODO: Ask the PT devs why conv is special and only shows up if dispatching
// through the autograd keys.
#define ACAP_DISPATCH_KEY PrivateUse2
#define ACAP_GRAD_DISPATCH_KEY AutogradPrivateUse2
static c10::DispatchKey kAcapDispatchKey = c10::DispatchKey::ACAP_DISPATCH_KEY;
static c10::DispatchKey kAcapGradDispatchKey =
    c10::DispatchKey::ACAP_GRAD_DISPATCH_KEY;

AcapController::TracedSchemaOpBuilder::TracedSchemaOpBuilder(
    AcapController &parent, MlirContext context, MlirLocation loc,
    const c10::OperatorHandle &opHandle)
    : parent(parent), loc(loc), opHandle(opHandle) {}

void AcapController::TracedSchemaOpBuilder::addOperand(const IValue &value) {
  MlirValue mlirValue = parent.mapIValueToMlirValue(loc, value);
  if (mlirValueIsNull(mlirValue)) {
    std::stringstream out;
    const std::string &kernelName = opHandle.operator_name().name;
    out << "Unsupported capture value passed to kernel '" << kernelName << "' ("
        << value.tagKind() << "): " << value;
    throw std::invalid_argument(out.str());
  }
  operands.push_back(mlirValue);
}

void AcapController::TracedSchemaOpBuilder::addResult(const IValue &value) {
  MlirType resultType = parent.mapIValueToMlirType(loc, value);
  if (mlirTypeIsNull(resultType)) {
    std::stringstream out;
    const std::string &kernelName = opHandle.operator_name().name;
    out << "Unsupported capture value returned from kernel '" << kernelName
        << "' (" << value.tagKind() << "): " << value;
    throw std::invalid_argument(out.str());
  }
  if (value.isTensor()) {
    resultIndexToTensorMap.emplace_back(resultCount++, value.toTensor());
  }
  resultTypes.push_back(resultType);
}

MlirOperation AcapController::TracedSchemaOpBuilder::create() {
  MlirOperation op =
      createOperationFromSchema(parent.funcBuilder->getEntryBlock(), loc,
                                opHandle.schema(), resultTypes, operands);
  // Map result tensors.
  for (auto &it : resultIndexToTensorMap) {
    MlirValue result = mlirOperationGetResult(op, it.first);
    parent.funcBuilder->mapTensor(it.second, result);
  }
  return op;
}

std::list<AcapController::Activation> &
AcapController::getThreadLocalActiveStack() {
  static thread_local std::list<Activation> threadLocalActiveStack;
  return threadLocalActiveStack;
}

py::object AcapController::contextEnter() {
  auto &stack = getThreadLocalActiveStack();
  stack.emplace_front(shared_from_this());
  Activation &current = stack.front();
  c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
  current.includeGuard =
      std::make_unique<c10::impl::IncludeDispatchKeyGuard>(keySet);
  return py::cast(this);
}

void AcapController::contextExit(py::object exc_type, py::object exc_val,
                                 py::object exc_tb) {
  auto &stack = getThreadLocalActiveStack();
  if (stack.empty() || stack.front().controller.get() != this) {
    throw std::runtime_error("Mismatched context manager __exit__");
  }
  stack.pop_front();

  if (!hasReturned) {
    returns({});
  }
}

void AcapController::returns(std::vector<at::Tensor> tensors) {
  verifyHasNotReturned();

  std::vector<MlirType> returnsTypes;
  std::vector<MlirValue> returnsValues;
  for (auto &tensor : tensors) {
    MlirValue v = funcBuilder->lookupTensor(tensor);
    if (mlirValueIsNull(v)) {
      debugTrace(
          "Return of imported-constant tensor (intentional memorization?)");
      v = importTensorByValue(tensor);
    }

    returnsTypes.push_back(mlirValueGetType(v));
    returnsValues.push_back(v);
  }

  MlirLocation loc = getCurrentLocation();
  OperationStateHolder s("std.return", loc);
  mlirOperationStateAddOperands(s, returnsValues.size(), returnsValues.data());
  funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(
      s.createOperation());
  funcBuilder->rewriteFuncReturnTypes(returnsTypes);
  hasReturned = true;
}

std::shared_ptr<AcapController>
AcapController::getCurrentThreadAcapController() {
  auto &stack = getThreadLocalActiveStack();
  if (stack.empty())
    return nullptr;
  return stack.front().controller;
}

void AcapController::verifyHasNotReturned() {
  if (hasReturned) {
    throw std::runtime_error(
        "Function has already returned. Cannot trace more operations.");
  }
}

/* static */
void AcapController::fallbackKernel(const OperatorHandle &opHandle,
                                    Stack *stack) {
  auto redispatchCallback = [&]() {
    // Exclude recursive dispatch to this kernel.
    c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
    // Passthrough.
    auto &dispatcher = c10::Dispatcher::singleton();
    dispatcher.callBoxed(opHandle, stack);
  };

  auto current = getCurrentThreadAcapController();
  if (!current) {
    redispatchCallback();
    return;
  }
  current->fallbackKernelImpl(opHandle, stack, redispatchCallback);
}

at::Tensor AcapController::convolutionKernel(
    const at::Tensor &input, const at::Tensor &weight,
    const c10::optional<at::Tensor> &bias, const at::IntArrayRef stride,
    const at::IntArrayRef padding, const at::IntArrayRef dilation,
    const bool transposed, const at::IntArrayRef output_padding,
    const int64_t groups) {
  static c10::OperatorName opName{"aten::convolution", ""};
  auto &dispatcher = c10::Dispatcher::singleton();
  auto opHandle = dispatcher.findOp(opName);
  assert(opHandle && "could not find convolution op");
  if (isDebugTraceEnabled()) {
    std::stringstream s;
    s << "Convolution (unboxed) dispatch: " << opHandle->schema();
    debugTrace(s.str());
  }

  auto opTyped = opHandle->typed<at::Tensor(
      const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &,
      const at::IntArrayRef, const at::IntArrayRef, const at::IntArrayRef,
      const bool, const at::IntArrayRef, const int64_t)>();

  // Exclude recursive calls: convolution is completely emitted by this
  // kernel.
  c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
  c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);

  auto current = getCurrentThreadAcapController();
  if (!current) {
    return opTyped.redispatch(c10::DispatchKeySet({c10::DispatchKey::AutogradOther}), input,
                              weight, bias, stride, padding, dilation,
                              transposed, output_padding, groups);
  }

  MlirContext context = current->funcBuilder->getContext();
  MlirLocation loc = current->getCurrentLocation();
  std::string kernelName{"aten::convolution"};
  TracedSchemaOpBuilder opBuilder{*current, context, loc, *opHandle};

  opBuilder.addOperand(IValue(input));
  opBuilder.addOperand(IValue(weight));
  // This is really sad: instead of storing a none in the optional, it stores
  // an undefined tensor, which cannot convert to an IValue :(
  // TODO: File PyTorch bug. Perhaps this is why they don't support boxing
  // for it.
  IValue biasIValue;
  if (bias && bias->defined()) {
    biasIValue = IValue(bias);
  } else {
    biasIValue = IValue(c10::optional<at::Tensor>());
  }
  opBuilder.addOperand(biasIValue);
  opBuilder.addOperand(IValue(stride));
  opBuilder.addOperand(IValue(padding));
  opBuilder.addOperand(IValue(dilation));
  opBuilder.addOperand(IValue(transposed));
  opBuilder.addOperand(IValue(output_padding));
  opBuilder.addOperand(IValue(groups));

  auto result = opTyped.redispatch(
      c10::DispatchKeySet({c10::DispatchKey::AutogradOther}), input, weight, bias, stride, padding,
      dilation, transposed, output_padding, groups);
  opBuilder.addResult(result);
  opBuilder.create();
  return result;
}

std::tuple<at::Tensor, at::Tensor, at::Tensor>
AcapController::mklConvolutionBackward(
    const at::Tensor &input, const at::Tensor &grad_output,
    const at::Tensor &weight, const at::IntArrayRef padding,
    const at::IntArrayRef stride, const at::IntArrayRef dilation,
    const int64_t groups, std::array<bool, 3> output_mask) {
  static c10::OperatorName opName{"aten::mkldnn_convolution_backward", ""};
  auto &dispatcher = c10::Dispatcher::singleton();
  auto opHandle = dispatcher.findOp(opName);
  assert(opHandle && "could not find mkldnn_convolution_backward op");
  if (isDebugTraceEnabled()) {
    std::stringstream s;
    s << "mkldnn_convolution_backward dispatch: " << opHandle->schema();
    debugTrace(s.str());
  }

  auto opTyped = opHandle->typed<std::tuple<at::Tensor, at::Tensor, at::Tensor>(
      const at::Tensor &input, const at::Tensor &grad_output,
      const at::Tensor &weight, const at::IntArrayRef padding,
      const at::IntArrayRef stride, const at::IntArrayRef dilation,
      const int64_t groups, std::array<bool, 3> output_mask)>();

  // Exclude recursive calls: convolution is completely emitted by this
  // kernel.
  c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
  c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);

  auto current = getCurrentThreadAcapController();
  if (!current) {
    return opTyped.redispatch(c10::DispatchKeySet({c10::DispatchKey::AutogradOther}), input,
                              grad_output, weight, padding, stride,
                              dilation, groups, output_mask);
  }

  // Emit the call as if to aten::convolution_backward_overrideable, the
  // generic, fully parameterized version that backends are supposed to
  // implement. Requires some parameter swizzling.
  // It has the signature:
  //   convolution_backward_overrideable(Tensor grad_output, Tensor input,
  //     Tensor weight, int[] stride, int[] padding, int[] dilation,
  //     bool transposed, int[] output_padding, int groups,
  //     bool[3] output_mask) ->
  //     (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
  MlirContext context = current->funcBuilder->getContext();
  MlirLocation loc = current->getCurrentLocation();
  std::string kernelName{"aten::convolution_backward"};
  static c10::OperatorName emitOpName{"aten::convolution_backward_overrideable",
                                      ""};
  auto emitOpHandle = dispatcher.findOp(emitOpName);
  assert(emitOpHandle && "could not find convolution_backward_overrideable op");
  TracedSchemaOpBuilder opBuilder{*current, context, loc, *emitOpHandle};

  opBuilder.addOperand(IValue(grad_output));
  opBuilder.addOperand(IValue(input));
  opBuilder.addOperand(IValue(weight));
  opBuilder.addOperand(IValue(stride));
  opBuilder.addOperand(IValue(padding));
  opBuilder.addOperand(IValue(dilation));
  opBuilder.addOperand(IValue(false));
  std::vector<int64_t> output_padding(padding.size()); // Not provided.
  opBuilder.addOperand(IValue(at::IntArrayRef(output_padding)));
  opBuilder.addOperand(IValue(groups));
  opBuilder.addOperand(IValue(output_mask));

  auto results = opTyped.redispatch(
      c10::DispatchKeySet({c10::DispatchKey::AutogradCPU}), input, grad_output, weight, padding,
      stride, dilation, groups, output_mask);

  opBuilder.addResult(std::get<0>(results));
  opBuilder.addResult(std::get<1>(results));
  opBuilder.addResult(std::get<2>(results));
  opBuilder.create();
  return results;
}

at::Tensor &AcapController::copyUnderKernel(at::Tensor &self,
                                            const at::Tensor &src,
                                            bool non_blocking) {
  static c10::OperatorName opName{"aten::copy_", ""};
  auto &dispatcher = c10::Dispatcher::singleton();
  auto opHandle = dispatcher.findOp(opName);
  assert(opHandle && "could not find copy_ op");
  if (isDebugTraceEnabled()) {
    std::stringstream s;
    s << "copy_ dispatch: " << opHandle->schema();
    debugTrace(s.str());
  }

  auto opTyped = opHandle->typed<at::Tensor &(
      at::Tensor & self, const at::Tensor &src, bool non_blocking)>();

  // Exclude recursive calls.
  c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
  c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);

  auto current = getCurrentThreadAcapController();
  if (!current) {
    return opTyped.redispatch(c10::DispatchKeySet({c10::DispatchKey::AutogradOther}), self,
                              src, non_blocking);
  }

  MlirContext context = current->funcBuilder->getContext();
  MlirLocation loc = current->getCurrentLocation();
  TracedSchemaOpBuilder opBuilder{*current, context, loc, *opHandle};

  opBuilder.addOperand(IValue(self));
  opBuilder.addOperand(IValue(src));
  auto &result = opTyped.redispatch(c10::DispatchKeySet({c10::DispatchKey::CPU}), self, src,
                                    non_blocking);
  opBuilder.addResult(result);
  opBuilder.create();
  return result;
}

at::Tensor AcapController::arangeBackendSelectKernel(
    const at::Scalar &end, c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout, c10::optional<at::Device> device,
    c10::optional<bool> pin_memory) {
  static c10::OperatorName opName{"aten::arange", ""};
  auto &dispatcher = c10::Dispatcher::singleton();
  auto opHandle = dispatcher.findOp(opName);
  assert(opHandle && "could not find arange op");

  // Exclude recursive calls.
  c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
  c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);

  // Dispatching in this fashion replicates the exact way that PyTorch
  // built-in handlers dispatch to BackendSelect kernels.
  auto targetDk = c10::computeDispatchKey(dtype, layout, device);
  auto opTyped = opHandle->typed<at::Tensor(
      const at::Scalar &end, c10::optional<at::ScalarType> dtype,
      c10::optional<at::Layout> layout, c10::optional<at::Device> device,
      c10::optional<bool> pin_memory)>();
  return opTyped.redispatch(c10::DispatchKeySet({targetDk}), end, dtype, layout, device,
                            pin_memory);
}

MlirLocation AcapController::getCurrentLocation() {
  return mlirLocationUnknownGet(funcBuilder->getContext());
}

void AcapController::fallbackKernelImpl(
    const OperatorHandle &opHandle, Stack *stack,
    std::function<void()> redispatchCallback) {
  verifyHasNotReturned();
  if (isDebugTraceEnabled()) {
    std::stringstream s;
    s << "Fallback (boxed) dispatch: " << opHandle.schema()
      << " (stack size=" << stack->size() << ")";
    debugTrace(s.str());
  }

  // Exclude recursive dispatch to this kernel.
  c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);

  const FunctionSchema &schema = opHandle.schema();

  // Check for unsupported.
  if (schema.is_vararg() || schema.is_varret()) {
    throw std::invalid_argument(
        "Cannot capture ops with variable arguments or returns");
  }

  MlirContext context = funcBuilder->getContext();
  MlirLocation loc = getCurrentLocation();
  auto kernelName = schema.name();
  TracedSchemaOpBuilder opBuilder{*this, context, loc, opHandle};

  // Map arguments to operands.
  // This must be accumulated into the OperationState prior to re-dispatch
  // since the stack is modified at that point.
  size_t argCount = schema.arguments().size();
  assert(stack->size() >= argCount && "stack too short");
  for (auto argIt = stack->end() - argCount; argIt != stack->end(); ++argIt) {
    opBuilder.addOperand(*argIt);
  }

  // Invoke the original kernel.
  redispatchCallback();

  // Map returns to results.
  size_t returnCount = schema.returns().size();
  assert(stack->size() >= returnCount && "stack too short");
  for (auto returnIt = stack->end() - returnCount; returnIt != stack->end();
       ++returnIt) {
    opBuilder.addResult(*returnIt);
  }

  opBuilder.create();
}

MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
                                               const IValue &ival) {
  if (ival.isScalar()) {
    return funcBuilder->getScalarConstant(loc, ival.toScalar());
  }
  if (ival.isTensor()) {
    auto tensor = ival.toTensor();
    if (!tensor.defined()) {
      // Optional tensors ("Tensor?" type) are represented as Tensor ivals
      // that are undefined.
      return funcBuilder->getNoneConstant(loc);
    }

    // Is it an already mapped tensor?
    MlirValue mappedValue = funcBuilder->lookupTensor(ival.toTensor());
    if (!mlirValueIsNull(mappedValue)) {
      return mappedValue;
    }

    mappedValue = importTensorByValue(ival.toTensor());
    assert(mappedValue.ptr);
    return mappedValue;
  }
  if (ival.isBool()) {
    // TODO: Switch to the numpy.bool type as that is a closer domain match.
    return funcBuilder->getBoolConstant(loc, ival.toBool());
  }
  if (ival.isList()) {
    auto list = ival.toList();
    std::vector<MlirValue> elements;
    for (IValue element : list) {
      elements.push_back(mapIValueToMlirValue(loc, element));
    }
    return funcBuilder->buildList(loc,
        typeMapper.mapFromTorchType(loc, list.elementType()), elements);
  }
  if (ival.isNone()) {
    return funcBuilder->getNoneConstant(loc);
  }
  if (ival.isDevice()) {
    // TODO: Do we need to model/preserve device? Currently, just None'ing
    // it out.
    return funcBuilder->getNoneConstant(loc);
  }
  return {nullptr};
  // TODO: Implement mappings for the whole set (relevant to this use case):
  //   _(Tensor)
  //   _(Double)
  //   _(Int)
  //   _(Tuple)
  //   _(String)
  //   _(Blob)
  //   _(GenericList)
  //   _(GenericDict)
  //   _(Future)
  //   _(Device)
  //   _(Object)
  //   _(PyObject)
  //   _(Uninitialized)
  //   _(Capsule)
  //   _(RRef)
  //   _(Generator)
}

MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
                                             const IValue &ival) {
  if (ival.isScalar()) {
    return typeMapper.mapFromTorchScalarType(ival.toScalar().type());
  }
  if (ival.isTensor()) {
    return typeMapper.forwardTensorToType(ival.toTensor());
  }
  if (ival.isBool()) {
    // TODO: Switch to the numpy.bool type as that is a closer domain match.
    return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
  }
  if (ival.isList()) {
    return torchMlirTorchListTypeGet(
        typeMapper.mapFromTorchType(loc, ival.toList().elementType()));
  }
  if (ival.isNone()) {
    return torchMlirTorchNoneTypeGet(funcBuilder->getContext());
  }
  if (ival.isDevice()) {
    return torchMlirTorchNoneTypeGet(funcBuilder->getContext());
  }
  return {nullptr};
}

MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
  auto loc = getCurrentLocation();
  MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
  MlirOperation tensorOp = createMlirOperationAtEnd(
      funcBuilder->getEntryBlock(), "torch.tensor.literal", loc,
      torchMlirTorchNonValueTensorTypeGetFromShaped(
          mlirAttributeGetType(denseElements)),
      toMlirNamedAttribute("value", denseElements));
  MlirValue tensorValue = mlirOperationGetResult(tensorOp, 0);
  funcBuilder->mapTensor(tensor, tensorValue);
  return tensorValue;
}

TORCH_LIBRARY_IMPL(aten, BackendSelect, m) {
  // PyTorch logs a warning when kernels are overridden, which is unavoidable
  // for factory-function BackendSelect kernels (there is not yet a "safe"
  // override mechanism). So, just silence it. Any of them here are coded
  // to be a superset of the default functionality, and there are only a few.
  auto orig_log_level = FLAGS_caffe2_log_level;
  FLAGS_caffe2_log_level = c10::GLOG_ERROR;

  // Disable capture of arange: causes it to memorize the resulting tensor.
  m.impl("arange", &AcapController::arangeBackendSelectKernel);

  // Restore log level.
  FLAGS_caffe2_log_level = orig_log_level;
}

TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<
             &AcapController::fallbackKernel>());
}

TORCH_LIBRARY_IMPL(aten, ACAP_DISPATCH_KEY, m) {
  m.impl("copy_", &AcapController::copyUnderKernel);
}

TORCH_LIBRARY_IMPL(aten, ACAP_GRAD_DISPATCH_KEY, m) {
  // The at::convolution op is special in several ways. First, it presently
  // does not support boxing, so all of the usual fanciness does not apply
  // and it cannot be intercepted by generic fallthroughs, which is what
  // would usually allow us to avoid intercepting it at the gradient phase.
  // Second, the default implementation (see
  // aten/src/ATen/native/Convolution.cpp) is very switchy based on hard-coded
  // assumptions about device type. If we do nothing here, we will at best
  // intercept an mkldnn_convolution, cudnn_convolution, etc on the backend
  // dispatch keys. Non-standard backends that don't have these switches
  // just route to aten::convolution_overrideable (see the else in
  // aten::convolution) as a convenience, but that is mostly a pass-through
  // (except for 3d convolutions which contain a trailing squeeze that needs
  // special casing). Therefore, we just intercept the aten::convolution op,
  // record it specially, and then mask ourselves off and ask the CPU backend
  // to invoke it. Not awesome.
  // Presumably this is on someone's list to adapt to the dispatch machinery
  // in a more appropriate way, but as the core of what the framework is,
  // perhaps people are reticent to touch it. Maybe someday, this can go away.
  m.impl("convolution", &AcapController::convolutionKernel);

  // Sadly, there is no easy intercept point for the backwards convolution
  // kernel which allows for chaining to an existing backend. And convolution
  // is exceptionally special cased in this way, more so than other ops.
  // The "solution" is to intercept the backend specific backward convolution
  // ops, emit it with the signature of the more generic
  // "convolution_backward_overrideable" op, which is available for generic
  // backends, and twiddle the parameters needed to get back to that form.
  // For MKL, which is effectively the CPU implementation, this just means that
  // some parameters are swapped and the full generality is not supported.
  // The "right" answer at some point is probably just to implement a
  // convolution kernel that fully does what is needed and delegates to an
  // appropriate implementation behind the scenes.
  m.impl("mkldnn_convolution_backward", AcapController::mklConvolutionBackward);
}
@@ -1,131 +0,0 @@
//===- acap_dispatch.h ------------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//
// "ATen Capture" dispatcher: Defines facility for capturing programs by
// registering dispatch keys to intercept op execution.
// References:
//   http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIRPLUGIN_CSRC_BUILDER_ACAP_DISPATCH_H
#define TORCHMLIRPLUGIN_CSRC_BUILDER_ACAP_DISPATCH_H

#include <list>
#include <memory>

#include "../pybind.h"

#include "func_builder.h"

#include "mlir-c/IR.h"

#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/ivalue.h>
#include <c10/core/impl/LocalDispatchKeySet.h>

namespace torch_mlir {

/// Main entry point for managing device capture.
class AcapController : public std::enable_shared_from_this<AcapController> {
public:
  AcapController(TypeMapper &typeMapper,
                 std::unique_ptr<FuncBuilder> funcBuilder)
      : typeMapper(typeMapper), funcBuilder(std::move(funcBuilder)) {}

  // Enter and exit the context manager.
  pybind11::object contextEnter();
  void contextExit(pybind11::object exc_type, pybind11::object exc_val,
                   pybind11::object exc_tb);

  // Terminates capture and returns tensors from the function.
  void returns(std::vector<at::Tensor> tensors);

  // Returns the current AcapController (if it has been activated on this
  // thread). Returns nullptr if none (not active on the current thread).
  static std::shared_ptr<AcapController> getCurrentThreadAcapController();

  // The fallback boxed kernel that we route captured dispatches through.
  static void fallbackKernel(const c10::OperatorHandle &opHandle,
                             c10::Stack *stack);

  // Kernel implementation for the boxing-incompatible convolution kernel.
  static at::Tensor
  convolutionKernel(const at::Tensor &input, const at::Tensor &weight,
                    const c10::optional<at::Tensor> &bias,
                    const at::IntArrayRef stride, const at::IntArrayRef padding,
                    const at::IntArrayRef dilation, const bool transposed,
                    const at::IntArrayRef output_padding, const int64_t groups);

  // Kernel implementation for the boxing-incompatible convolution backward
  // kernel.
  static std::tuple<at::Tensor, at::Tensor, at::Tensor> mklConvolutionBackward(
      const at::Tensor &input, const at::Tensor &grad_output,
      const at::Tensor &weight, const at::IntArrayRef padding,
      const at::IntArrayRef stride, const at::IntArrayRef dilation,
      const int64_t groups, std::array<bool, 3> output_mask);

  // Implementation for the aten::copy_ kernel.
  static at::Tensor &copyUnderKernel(at::Tensor &self, const at::Tensor &src,
                                     bool non_blocking);

  // Backend select kernel for arange factory function.
  static at::Tensor arangeBackendSelectKernel(
      const at::Scalar &end, c10::optional<at::ScalarType> dtype,
      c10::optional<at::Layout> layout, c10::optional<at::Device> device,
      c10::optional<bool> pin_memory);

private:
  /// Builds an MLIR operation for a Torch operator step by step.
  class TracedSchemaOpBuilder {
  public:
    TracedSchemaOpBuilder(AcapController &parent, MlirContext context,
                          MlirLocation loc,
                          const c10::OperatorHandle &opHandle);
    void addOperand(const c10::IValue &value);
    void addResult(const c10::IValue &result);
    MlirOperation create();

  private:
    AcapController &parent;
    MlirLocation loc;
    const c10::OperatorHandle &opHandle;
    std::vector<MlirValue> operands;
    std::vector<MlirType> resultTypes;
    int resultCount = 0;
    std::vector<std::pair<size_t, at::Tensor>> resultIndexToTensorMap;
  };

  MlirLocation getCurrentLocation();
  void redispatch(const c10::OperatorHandle &opHandle, c10::Stack *stack);
  void fallbackKernelImpl(const c10::OperatorHandle &opHandle,
                          c10::Stack *stack,
                          std::function<void()> redispatchCallback);
  MlirValue mapIValueToMlirValue(MlirLocation loc, const c10::IValue &ival);
  MlirType mapIValueToMlirType(MlirLocation loc, const c10::IValue &ival);
  /// Imports a tensor by value (as a constant), remembering the association.
  MlirValue importTensorByValue(at::Tensor tensor);
  void verifyHasNotReturned();
  struct Activation {
    Activation(std::shared_ptr<AcapController> controller)
        : controller(std::move(controller)) {}
    std::shared_ptr<AcapController> controller;
    // The RAII dispatch key guard is not movable, so heap allocate it. This is
    // a bit outside of its intended design, but since this is thread local as
    // well, it should be fine.
    std::unique_ptr<c10::impl::IncludeDispatchKeyGuard> includeGuard;
    std::unique_ptr<c10::impl::ExcludeDispatchKeyGuard> excludeGuard;
  };
  // Gets the thread local stack of active acap controllers.
  static std::list<Activation> &getThreadLocalActiveStack();

  TypeMapper &typeMapper;
  std::unique_ptr<FuncBuilder> funcBuilder;
  bool hasReturned = false;
};

} // namespace torch_mlir

#endif // TORCHMLIRPLUGIN_CSRC_BUILDER_ACAP_DISPATCH_H
@@ -1,29 +0,0 @@
//===- debug.cpp ------------------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//

#include "debug.h"

#include <iostream>

namespace torch_mlir {

static bool debugTraceToStderrEnabled = false;

/// Whether debug tracing is enabled and calls to debugTrace() are more than
/// a no-op.
bool isDebugTraceEnabled() { return debugTraceToStderrEnabled; }

/// Writes a message to the debug trace log.
void debugTrace(const std::string &message) {
  if (debugTraceToStderrEnabled)
    std::cerr << "TORCH_MLIR TRACE: " << message << "\n" << std::flush;
}

/// Enables writing debug traces to stderr.
void enableDebugTraceToStderr() { debugTraceToStderrEnabled = true; }

} // namespace torch_mlir
@@ -1,27 +0,0 @@
//===- debug.h --------------------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIRPLUGIN_CSRC_BUILDER_DEBUG_H
#define TORCHMLIRPLUGIN_CSRC_BUILDER_DEBUG_H

#include <string>

namespace torch_mlir {

/// Whether debug tracing is enabled and calls to debugTrace() are more than
/// a no-op.
bool isDebugTraceEnabled();

/// Writes a message to the debug trace log.
void debugTrace(const std::string &message);

/// Enables writing debug traces to stderr.
void enableDebugTraceToStderr();

} // namespace torch_mlir

#endif // TORCHMLIRPLUGIN_CSRC_BUILDER_DEBUG_H
@@ -1,143 +0,0 @@
//===- func_builder.cpp ---------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//

#include "func_builder.h"

#include "op_builder.h"

#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "torch-mlir-c/TorchTypes.h"

using namespace torch_mlir;

std::unique_ptr<FuncBuilder>
FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
                            MlirLocation location, const std::string &name,
                            std::vector<MlirType> &inputTypes) {
  auto context = mlirLocationGetContext(location);
  // TODO: Create a dedicated API upstream for creating/manipulating func ops.
  // (this is fragile and reveals details that are not guaranteed).
  std::vector<MlirNamedAttribute> funcAttrs;
  funcAttrs.push_back(toMlirNamedAttribute(
      "type", mlirTypeAttrGet(mlirFunctionTypeGet(
                  context, inputTypes.size(), inputTypes.data(),
                  /*numResults=*/0, /*results=*/nullptr))));
  funcAttrs.push_back(toMlirNamedAttribute(
      "sym_name", mlirStringAttrGet(
                      context, mlirStringRefCreate(name.data(), name.size()))));

  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef("builtin.func"), location);
  mlirOperationStateAddAttributes(&state, funcAttrs.size(), funcAttrs.data());
  {
    // Don't access these once ownership transferred.
    MlirRegion newBodyRegion = mlirRegionCreate();
    MlirBlock newEntryBlock =
        mlirBlockCreate(inputTypes.size(), inputTypes.data());
    mlirRegionInsertOwnedBlockAfter(newBodyRegion, {nullptr}, newEntryBlock);
    mlirOperationStateAddOwnedRegions(&state, 1, &newBodyRegion);
  }

  // Need to re-lookup the region/block because we relinquished ownership above.
  MlirOperation funcOp = mlirOperationCreate(&state);
  MlirRegion bodyRegion = mlirOperationGetRegion(funcOp, 0);
  MlirBlock entryBlock = mlirRegionGetFirstBlock(bodyRegion);

  inserter(funcOp);
  return std::unique_ptr<FuncBuilder>(new FuncBuilder(
      context, funcOp, BlockBuilder(entryBlock, /*returnOp=*/{nullptr}, true)));
}

void FuncBuilder::rewriteFuncReturnTypes(std::vector<MlirType> &resultTypes) {
  // Get inputs from current function type.
  MlirAttribute funcTypeAttr =
      mlirOperationGetAttributeByName(funcOp, toMlirStringRef("type"));
  assert(!mlirAttributeIsNull(funcTypeAttr) &&
         "function missing 'type' attribute");
  assert(mlirAttributeIsAType(funcTypeAttr) &&
         "function type is not a TypeAttr");
  MlirType funcType = mlirTypeAttrGetValue(funcTypeAttr);
  std::vector<MlirType> inputTypes;
  for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(funcType); i < e; ++i) {
    inputTypes.push_back(mlirFunctionTypeGetInput(funcType, i));
  }

  // Make new function type.
  MlirType newFuncType =
      mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(),
                          resultTypes.size(), resultTypes.data());
  MlirAttribute newFuncTypeAttr = mlirTypeAttrGet(newFuncType);
  mlirOperationSetAttributeByName(funcOp, toMlirStringRef("type"),
                                  newFuncTypeAttr);
  (void)newFuncTypeAttr;
}

MlirValue FuncBuilder::insertConstantOp(MlirOperation op) {
  mlirBlockInsertOwnedOperationAfter(entryBlock.getBlock(), prevConstantOp, op);
  prevConstantOp = op;
  return mlirOperationGetResult(op, 0);
}

MlirValue FuncBuilder::lookupTensor(at::Tensor tensor) {
  for (auto it = tensorValueMap.rbegin(), e = tensorValueMap.rend(); it != e;
       ++it) {
    if (it->first.is_same(tensor))
      return it->second;
  }
  return {nullptr};
}

MlirValue FuncBuilder::getScalarConstant(MlirLocation loc, at::Scalar s) {
  // Note that interpreter "scalars" match the Python semantics and are
  // represented as one of double or int64_t, with a special tag for whether
  // it should be interpreted as a bool.
  if (s.isIntegral(/*includeBool=*/false)) {
    MlirType t = torchMlirTorchIntTypeGet(context);
    MlirAttribute value =
        mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), s.to<int64_t>());
    MlirOperation op = createMlirOperation(
        "torch.constant.int", loc, t, toMlirNamedAttribute("value", value));
    insertConstantOp(op);
    return mlirOperationGetResult(op, 0);
  }
  if (s.isFloatingPoint()) {
    MlirType t = torchMlirTorchFloatTypeGet(context);
    MlirAttribute value = mlirFloatAttrDoubleGet(
        context, mlirF64TypeGet(context), s.to<double>());
    MlirOperation op = createMlirOperation(
        "torch.constant.float", loc, t, toMlirNamedAttribute("value", value));
    insertConstantOp(op);
    return mlirOperationGetResult(op, 0);
  }
  if (s.isBoolean()) {
    return getBoolConstant(loc, s.to<bool>());
  }
  // TODO: s.isComplex()

  throw std::invalid_argument("TODO: Scalar of unknown kind");
}

MlirValue FuncBuilder::getBoolConstant(MlirLocation loc, bool v) {
  return insertConstantOp(OpBuilder(context).createBoolConstant(loc, v));
}

MlirValue FuncBuilder::getNoneConstant(MlirLocation loc) {
  return insertConstantOp(OpBuilder(context).createNoneConstant(loc));
}

MlirValue FuncBuilder::buildList(MlirLocation loc, MlirType elementType,
                                 std::vector<MlirValue> &elements) {
  MlirType resultType = torchMlirTorchListTypeGet(elementType);
  OperationStateHolder state{"torch.prim.ListConstruct", loc};
  mlirOperationStateAddResults(state, 1, &resultType);
  mlirOperationStateAddOperands(state, elements.size(), elements.data());
  MlirOperation op = state.createOperation();
  entryBlock.insertBeforeTerminator(op);
  return mlirOperationGetResult(op, 0);
}
@@ -1,153 +0,0 @@
//===- func_builder.h -------------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIRPLUGIN_CSRC_BUILDER_FUNC_BUILDER_H
#define TORCHMLIRPLUGIN_CSRC_BUILDER_FUNC_BUILDER_H

#include "mlir_utils.h"
#include "torch_to_mlir_utils.h"

#include "mlir-c/IR.h"

#include <ATen/Tensor.h>
#include <ATen/core/function_schema.h>

namespace torch_mlir {

/// Wraps an MlirOperationState, deallocating it on destruction unless
/// it has been consumed.
class OperationStateHolder {
public:
  OperationStateHolder(const char *name, MlirLocation loc)
      : state(mlirOperationStateGet(toMlirStringRef(name), loc)) {}
  OperationStateHolder(const OperationStateHolder &) = delete;
  OperationStateHolder(OperationStateHolder &&other) = delete;
  ~OperationStateHolder() {
    if (owned) {
      // Destroying is done by creating and then destroying the operation.
      mlirOperationDestroy(createOperation());
    }
  }

  operator MlirOperationState *() { return &state; }

  MlirOperation createOperation() {
    assert(owned && "cannot createOperation on unowned state");
    owned = false;
    return mlirOperationCreate(&state);
  }

private:
  MlirOperationState state;
  bool owned = true;
};

/// Wraps an MlirBlock under construction, primarily tracking the terminator
/// and supporting manipulation of it. The terminator may be null if it has
/// not yet been constructed.
class BlockBuilder {
public:
  BlockBuilder(MlirBlock block, MlirOperation terminator, bool isReturn)
      : block(block), terminator(terminator), isReturn(isReturn) {}

  MlirBlock getBlock() { return block; }
  MlirOperation getTerminator() { return terminator; }
  bool getIsReturnTerminator() { return isReturn; }

  /// Inserts an owned operation before the terminator.
  void insertBeforeTerminator(MlirOperation op) {
    mlirBlockInsertOwnedOperationBefore(block, terminator, op);
  }

private:
  MlirBlock block;
  MlirOperation terminator;
  bool isReturn;
};

/// Wraps a 'func' MlirOperation and provides facilities for constructing
/// IR from some stream of Torch operations.
class FuncBuilder {
public:
  /// Callback for inserting a function.
  using Inserter = std::function<void(MlirOperation funcOp)>;

  /// Creates a new func op with the given characteristics. The created
  /// operation is not attached. The caller must either destroy it or add it
  /// to a parent.
  static std::unique_ptr<FuncBuilder>
  createFunction(Inserter &inserter, MlirLocation location,
                 const std::string &name, std::vector<MlirType> &inputTypes);

  MlirContext getContext() { return context; }
  MlirOperation getFuncOp() { return funcOp; }

  /// Gets the function's entry block.
  MlirBlock getEntryBlock() { return entryBlock.getBlock(); }
  BlockBuilder &getEntryBlockBuilder() { return entryBlock; }

  /// Rewrites the function's signature to return the given types. It is
  /// assumed that a compatible terminator has been added.
  void rewriteFuncReturnTypes(std::vector<MlirType> &resultTypes);

  /// Maps a live Tensor to an MlirValue.
  void mapTensor(at::Tensor tensor, MlirValue value) {
    tensorValueMap.push_back(std::make_pair(tensor, value));
  }

  /// Looks up a current mapping of tensor to an MlirValue, returning a null
  /// value if not found.
  MlirValue lookupTensor(at::Tensor tensor);

  /// Gets a scalar constant value.
  MlirValue getScalarConstant(MlirLocation loc, at::Scalar s);

  /// Gets a bool constant value.
  MlirValue getBoolConstant(MlirLocation loc, bool v);

  /// Gets a None constant value.
  MlirValue getNoneConstant(MlirLocation loc);

  /// Builds a list with the given elements.
  MlirValue buildList(MlirLocation loc, MlirType elementType,
                      std::vector<MlirValue> &elements);

private:
  FuncBuilder(MlirContext context, MlirOperation funcOp,
              BlockBuilder entryBlock)
      : context(context), funcOp(funcOp), entryBlock(std::move(entryBlock)) {
    (void)this->context;
  }

  /// Inserts a constant op into the function, returning its first result.
  /// The function maintains a constant area at the top where these are
  /// inserted.
  MlirValue insertConstantOp(MlirOperation op);

  MlirContext context;

  /// The func op under construction.
  MlirOperation funcOp;

  /// Block builder for the entry block.
  BlockBuilder entryBlock;

  /// Previously inserted constant op or null.
  MlirOperation prevConstantOp = {nullptr};

  /// Maps tensors to MlirValue. Unfortunately, this needs to be a linear scan
  /// because the impl pointer for the Tensor is not accessible. To make this
  /// slightly better, we add to the back and lookup in reverse under the idea
  /// that tensors may be mapped and accessed in proximity.
  /// TODO: Tensors referenced via an IValue support hash code lookup and
  /// identity checks. Switch to this instead of a linear scan.
  std::vector<std::pair<at::Tensor, MlirValue>> tensorValueMap;
};

} // namespace torch_mlir

#endif // TORCHMLIRPLUGIN_CSRC_BUILDER_FUNC_BUILDER_H
@@ -11,7 +11,6 @@
#include <memory>

#include "../pybind.h"
#include "func_builder.h"
#include "node_importer.h"

#include "mlir-c/IR.h"
@@ -8,6 +8,7 @@
#include "ivalue_importer.h"
#include "class_annotator.h"
#include "function_importer.h"
#include "torch_to_mlir_utils.h"

#include <unordered_map>
@@ -11,7 +11,6 @@
#include <memory>

#include "../pybind.h"
#include "func_builder.h"
#include "class_annotator.h"

#include "mlir-c/IR.h"
@@ -110,7 +110,7 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
      context(castPythonObjectToMlirContext(this->contextObj)),
      module(createEmptyModule(this->context)),
      moduleObj(castMlirModuleToPythonObject(module)),
      unknownLoc(mlirLocationUnknownGet(context)), typeMapper(this->context) {
      unknownLoc(mlirLocationUnknownGet(context)) {
  // TODO: Rework this once dialect registration C-APIs are in place.
  // https://reviews.llvm.org/D88162
  mlirRegisterAllDialects(context);
@@ -122,30 +122,6 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
  terminator = mlirBlockGetFirstOperation(getBodyBlock());
}

std::shared_ptr<AcapController>
ModuleBuilder::startCaptureFunction(std::string &name,
                                    std::vector<at::Tensor> args) {
  // TODO: Verify that arguments do not alias each other.
  std::vector<MlirType> inputTypes;
  for (auto &arg : args) {
    inputTypes.push_back(typeMapper.forwardTensorToType(arg));
  }

  // TODO: Extract a traceback and use in place of unknownLoc.
  auto inserter = createInserter();
  auto funcBuilder =
      FuncBuilder::createFunction(inserter, unknownLoc, name, inputTypes);
  // Map block arguments.
  MlirBlock entryBlock = funcBuilder->getEntryBlock();
  assert(mlirBlockGetNumArguments(entryBlock) ==
             static_cast<intptr_t>(args.size()) &&
         "entry block incorrect arg arity");
  for (size_t i = 0; i < args.size(); ++i) {
    funcBuilder->mapTensor(args[i], mlirBlockGetArgument(entryBlock, i));
  }
  return std::make_shared<AcapController>(typeMapper, std::move(funcBuilder));
}

torch::jit::StrongFunctionPtr
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
  MlirBlock block = getBodyBlock();
@@ -166,14 +142,6 @@ void ModuleBuilder::importModule(torch::jit::Module jitModule,
      mlirModuleGetContext(module), *classAnnotator);
}

FuncBuilder::Inserter ModuleBuilder::createInserter() {
  MlirBlock block = getBodyBlock();
  MlirOperation terminator = this->terminator;
  return [=](MlirOperation op) {
    mlirBlockInsertOwnedOperationBefore(block, terminator, op);
  };
}

MlirBlock ModuleBuilder::getBodyBlock() {
  MlirOperation moduleOp = mlirModuleGetOperation(module);
  return mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0));
@@ -184,8 +152,6 @@ void ModuleBuilder::bind(py::module &m) {
      .def(py::init<py::object>(), py::arg("context") = py::none())
      .def_property_readonly("context", &ModuleBuilder::getContextObj)
      .def_property_readonly("module", &ModuleBuilder::getModuleObj)
      .def("capture_function", &ModuleBuilder::startCaptureFunction,
           py::keep_alive<0, 1>())
      .def("import_function", &ModuleBuilder::importFunction)
      .def("import_module", &ModuleBuilder::importModule, py::arg("module"),
           py::arg("classAnnotator") = py::none());
@@ -10,7 +10,6 @@

#include "../pybind.h"

#include "acap_dispatch.h"
#include "class_annotator.h"

#include "mlir-c/IR.h"
@@ -34,10 +33,6 @@ public:
  pybind11::object getContextObj() { return contextObj; }
  pybind11::object getModuleObj() { return moduleObj; }

  // Starts a device-capture based function.
  std::shared_ptr<AcapController>
  startCaptureFunction(std::string &name, std::vector<at::Tensor> args);

  // Imports a traced function. Note that the python type
  // torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr.
  // Just a bit of naming cruft.
@@ -52,7 +47,6 @@ public:
                    py::object maybeClassAnnotator);

private:
  FuncBuilder::Inserter createInserter();
  MlirBlock getBodyBlock();

  // Capture references to the python-owned context and module. Ownership
@@ -64,8 +58,6 @@ private:
  pybind11::object moduleObj;
  MlirOperation terminator;
  MlirLocation unknownLoc;

  TypeMapper typeMapper;
};

} // namespace torch_mlir
@@ -11,7 +11,6 @@
#include <memory>

#include "../pybind.h"
#include "func_builder.h"

#include "mlir-c/IR.h"
@@ -6,12 +6,10 @@
//===----------------------------------------------------------------------===//

#include "../pybind.h"
#include "debug.h"

#include <ATen/core/dispatch/Dispatcher.h>

#include "../init_python_bindings.h"
#include "acap_dispatch.h"
#include "module_builder.h"
#include "class_annotator.h"
@@ -125,16 +123,7 @@ py::list GetRegisteredOps() {
} // namespace

void torch_mlir::InitBuilderBindings(py::module &m) {
  m.def("debug_trace_to_stderr", &enableDebugTraceToStderr);

  py::class_<AcapController, std::shared_ptr<AcapController>>(m,
                                                              "AcapController")
      .def("__enter__", &AcapController::contextEnter)
      .def("__exit__", &AcapController::contextExit)
      .def("returns", &AcapController::returns);
  m.def("get_registered_ops", &GetRegisteredOps, kGetRegisteredOpsDocstring);

  ModuleBuilder::bind(m);

  initClassAnnotatorBindings(m);
}
@@ -5,6 +5,7 @@
//
//===----------------------------------------------------------------------===//

#include "torch_to_mlir_utils.h"
#include "function_importer.h"
#include "ivalue_importer.h"
@@ -1,18 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.

import torch
import torch_mlir

# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
torch_mlir.debug_trace_to_stderr()

mb = torch_mlir.ModuleBuilder()
with mb.capture_function("arange_test", []) as f:
    x = torch.arange(10)
    f.returns([x])

# CHECK: %[[T:.*]] = torch.tensor.literal(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xsi64>) : !torch.tensor<[10],si64>
# CHECK: return %[[T]]
mb.module.operation.print()
@@ -1,52 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.

# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
# XFAIL: *

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch_mlir

torch_mlir.debug_trace_to_stderr()

N = 3
Cin = 16
Cout = 4
w = 10
h = 10

class Net(nn.Module):
    def __init__(self, Cin, Cout):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(Cin, Cout, (3,3))
    def forward(self, x):
        x = self.conv1(x)
        output = F.log_softmax(x, dim=1)
        return output

model = Net(Cin, Cout)

inputs = torch.ones((N,Cin,h,w))
loss = torch.nn.NLLLoss()
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, Cout)

mb = torch_mlir.ModuleBuilder()
with mb.capture_function("resa", [inputs, target]) as f:
    result = loss(model(inputs), target)
    result.backward()
    f.returns([result] + [p.grad for p in model.parameters()])

# CHECK: torch.operator "aten.convolution"
# CHECK: torch.operator "aten._log_softmax"
# CHECK: %[[FWD:.*]]:2 = torch.operator "aten.nll_loss2d_forward"
# CHECK: torch.operator "aten.nll_loss2d_backward"
# CHECK: torch.operator "aten._log_softmax_backward_data"
# CHECK: %[[BWD_CONV:.*]]:3 = torch.operator "aten.convolution_backward_overrideable"
# CHECK: %[[BWD_CONV_WEIGHTS:.*]] = aten.copy_{{.*}}%[[BWD_CONV]]#1
# CHECK: %[[BWD_CONV_BIAS:.*]] = aten.copy_{{.*}}%[[BWD_CONV]]#2
# CHECK: return %[[FWD]]#0, %[[BWD_CONV_WEIGHTS]], %[[BWD_CONV_BIAS]]
mb.module.operation.print(large_elements_limit=2)
@@ -1,33 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.


import torch
import torch_mlir

# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s

t0 = torch.randn((1,2,3,4))
t1 = torch.randn((1,2,3,4))
t2 = torch.randn((1,2,3,4))

mb = torch_mlir.ModuleBuilder()
with mb.capture_function("add3", [t0, t1, t2]) as f:
    t3 = t0 + t1 + t2
    f.returns([t3])
# NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
# CHECK-LABEL: func @add3(
# CHECK-SAME: %[[VAL_0:.*]]: !torch.tensor<[1,2,3,4],f32>, %[[VAL_1:.*]]: !torch.tensor<[1,2,3,4],f32>,
# CHECK-SAME: %[[VAL_2:.*]]: !torch.tensor<[1,2,3,4],f32>) -> !torch.tensor<[1,2,3,4],f32> {
# CHECK: %[[VAL_3:.*]] = torch.constant.int 1
# CHECK: %[[VAL_4:.*]] = torch.constant.int 1
# CHECK: %[[VAL_5:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1x2x3x4xf32>) : !torch.tensor<[1,2,3,4],f32>
# CHECK: %[[VAL_6:.*]] = torch.operator "aten.add.out"(%[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_5]]) : (!torch.tensor<[1,2,3,4],f32>, !torch.tensor<[1,2,3,4],f32>, !torch.int, !torch.tensor<[1,2,3,4],f32>) -> !torch.tensor<[1,2,3,4],f32>
# CHECK: %[[VAL_7:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1x2x3x4xf32>) : !torch.tensor<[1,2,3,4],f32>
# CHECK: %[[VAL_8:.*]] = torch.operator "aten.add.out"(%[[VAL_6]], %[[VAL_2]], %[[VAL_4]], %[[VAL_7]]) : (!torch.tensor<[1,2,3,4],f32>, !torch.tensor<[1,2,3,4],f32>, !torch.int, !torch.tensor<[1,2,3,4],f32>) -> !torch.tensor<[1,2,3,4],f32>
# CHECK: return %[[VAL_8]] : !torch.tensor<[1,2,3,4],f32>
# CHECK: }

print(mb.module)
|
|
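The two "aten.add.out" ops in the checks mirror each `+` being redirected to its out-variant with a freshly allocated destination, which is why the zero-filled torch.tensor.literal values appear as operands. A rough eager-mode equivalent, as a sketch:

tmp = torch.zeros(1, 2, 3, 4)
torch.add(t0, t1, alpha=1, out=tmp)   # first aten.add.out
out = torch.zeros(1, 2, 3, 4)
torch.add(tmp, t2, alpha=1, out=out)  # second aten.add.out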
@@ -1,25 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.

import torch
import torch_mlir

# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s

mb = torch_mlir.ModuleBuilder()

ones = torch.ones(42,123,4,5)

with mb.capture_function("bn2d", [ones]) as f:
    model = torch.nn.BatchNorm2d(123)
    result = model(ones)
    f.returns([result])

# TODO: This test exercises promotion of const to arrays, inplace zero_ and
# add, all of which should be checked individually because they have specific
# behavior.
# CHECK-LABEL: @bn2d
# CHECK: %[[RESULT:.*]]:3 = torch.operator "aten.native_batch_norm"(%arg0
# CHECK: return %[[RESULT]]#0 : !torch.tensor<[42,123,4,5],f32>
print(mb.module)
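The single "aten.native_batch_norm" in the checks corresponds, roughly, to the functional form below. A sketch only; the training=True flag and the momentum/eps values are assumptions matching a freshly constructed nn.BatchNorm2d:

import torch.nn.functional as F

# Training-mode batch norm over the captured input, using the module's
# freshly initialized parameters and running statistics.
ref = F.batch_norm(ones, model.running_mean, model.running_var,
                   model.weight, model.bias, training=True,
                   momentum=0.1, eps=1e-5)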
@@ -1,47 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.

# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_mlir

torch_mlir.debug_trace_to_stderr()

N = 3
Cin = 16
Cout = 4
w = 10
h = 10

class Net(nn.Module):
    def __init__(self, Cin, Cout):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(Cin, Cout, (3,3))

    def forward(self, x):
        x0 = self.conv1(x)
        x1 = self.conv1(x)
        z = torch.cat([x0, x1])
        output = F.log_softmax(z, dim=1)
        return output

model = Net(Cin, Cout)
inputs = torch.ones((N,Cin,h,w))
loss = torch.nn.NLLLoss()
target = torch.empty(2*N, 8, 8, dtype=torch.long).random_(0, Cout)

mb = torch_mlir.ModuleBuilder()
with mb.capture_function("conv_cat", [inputs, target]) as f:
    result = loss(model(inputs), target)
    f.returns([result])

# CHECK: "aten.convolution"
# CHECK: "aten.convolution"
# CHECK: torch.prim.ListConstruct
# CHECK: "aten._cat"
# CHECK: "aten._log_softmax.out"
# CHECK: "aten.nll_loss2d_forward"
mb.module.operation.print(large_elements_limit=2)
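The 2*N batch dimension of target follows from the torch.cat in forward: cat defaults to dim 0, so the two conv outputs stack along the batch axis. A quick shape check, as a sketch:

x0 = torch.ones(N, Cout, 8, 8)
x1 = torch.ones(N, Cout, 8, 8)
# Concatenation along dim 0 doubles the batch dimension.
assert torch.cat([x0, x1]).shape == (2*N, Cout, 8, 8)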
@@ -1,56 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.

import torch
import torch_mlir

# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s

mb = torch_mlir.ModuleBuilder()

N = 3
Cin = 16
Cout = 4
w = 10
h = 10

model = torch.nn.Conv2d(Cin, Cout, (3,3))
ref_model = torch.nn.Conv2d(Cin, Cout, (3,3))

ref_model.weight.data = model.weight.clone()
ref_model.bias.data = model.bias.clone()

softmax = torch.nn.LogSoftmax(dim=1)
loss = torch.nn.NLLLoss()

tensor = torch.randn(N, Cin, h, w)

with mb.capture_function("conv2d_fwd", [tensor]) as f:
    result = model(tensor)
    f.returns([result])

# NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
# CHECK-LABEL: func @conv2d_fwd(
# CHECK-SAME: %[[VAL_0:.*]]: !torch.tensor<[3,16,10,10],f32>) -> !torch.tensor<[3,4,8,8],f32> {
# CHECK: %[[VAL_1:.*]] = torch.constant.int 1
# CHECK: %[[VAL_2:.*]] = torch.constant.int 1
# CHECK: %[[VAL_3:.*]] = torch.constant.int 0
# CHECK: %[[VAL_4:.*]] = torch.constant.int 0
# CHECK: %[[VAL_5:.*]] = torch.constant.int 1
# CHECK: %[[VAL_6:.*]] = torch.constant.int 1
# CHECK: %[[VAL_7:.*]] = torch.constant.bool false
# CHECK: %[[VAL_8:.*]] = torch.constant.int 0
# CHECK: %[[VAL_9:.*]] = torch.constant.int 0
# CHECK: %[[VAL_10:.*]] = torch.constant.int 1
# CHECK: %[[VAL_11:.*]] = torch.tensor.literal(opaque<"_", "0xDEADBEEF"> : tensor<4x16x3x3xf32>) : !torch.tensor<[4,16,3,3],f32>
# CHECK: %[[VAL_12:.*]] = torch.tensor.literal(opaque<"_", "0xDEADBEEF"> : tensor<4xf32>) : !torch.tensor<[4],f32>
# CHECK: %[[VAL_13:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
# CHECK: %[[VAL_14:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
# CHECK: %[[VAL_15:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
# CHECK: %[[VAL_16:.*]] = torch.prim.ListConstruct %[[VAL_8]], %[[VAL_9]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
# CHECK: %[[VAL_17:.*]] = torch.operator "aten.convolution"(%[[VAL_0]], %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_7]], %[[VAL_16]], %[[VAL_10]]) : (!torch.tensor<[3,16,10,10],f32>, !torch.tensor<[4,16,3,3],f32>, !torch.tensor<[4],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool, !torch.list<!torch.int>, !torch.int) -> !torch.tensor<[3,4,8,8],f32>
# CHECK: return %[[VAL_17]] : !torch.tensor<[3,4,8,8],f32>
# CHECK: }

mb.module.operation.print(large_elements_limit=2)
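The !torch.tensor<[3,4,8,8],f32> result type in the checks is plain convolution size arithmetic, with stride 1, zero padding, and dilation 1 per the constants above. A sketch of the calculation:

# out = (in + 2*padding - dilation*(kernel - 1) - 1) // stride + 1
h_out = (h + 2*0 - 1*(3 - 1) - 1) // 1 + 1   # (10 - 3) + 1 = 8
w_out = (w + 2*0 - 1*(3 - 1) - 1) // 1 + 1   # 8
assert (h_out, w_out) == (8, 8)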
@@ -1,28 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.

import torch
import torch_mlir

# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s

mb = torch_mlir.ModuleBuilder()

t0 = torch.randn(4)
t1 = torch.randn(4)
t2 = torch.randn(4)

with mb.capture_function("multi_output", [t0, t1, t2]) as f:
    t4 = t0 + t1 + t2
    t5 = t4 + t1
    t6 = t5 + t4
    f.returns([t4, t5, t6])

# CHECK-LABEL: func @multi_output
# CHECK: %[[ADD0:.*]] = torch.operator "aten.add.out"(%arg0
# CHECK: %[[ADD1:.*]] = torch.operator "aten.add.out"(%[[ADD0]]
# CHECK: %[[ADD2:.*]] = torch.operator "aten.add.out"(%[[ADD1]]
# CHECK: %[[ADD3:.*]] = torch.operator "aten.add.out"(%[[ADD2]]
# CHECK: return %[[ADD1]], %[[ADD2]], %[[ADD3]]
print(mb.module)
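Four "aten.add.out" ops appear for three returned tensors because t0 + t1 + t2 desugars into two binary adds, leaving ADD0 as an unreturned intermediate. A sketch of the equivalent sequence:

a  = torch.add(t0, t1)   # ADD0 (intermediate, not returned)
t4 = torch.add(a, t2)    # ADD1
t5 = torch.add(t4, t1)   # ADD2
t6 = torch.add(t5, t4)   # ADD3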
@@ -1,30 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.

# RUN: %PYTHON %s | torch-mlir-opt -aten-recognize-kernels -numpy-public-functions-to-tensor -canonicalize | FileCheck %s
# TODO: Re-enable after adding support for 4-operand aten::add in `aten-recognize-kernels`.
# XFAIL: *

# TODO: This test should go away or become part of an e2e test suite. It is
# preserved right now as a stop-gap.

import torch
import torch_mlir

t0 = torch.randn((1,4))
t1 = torch.randn((4,1))

mb = torch_mlir.ModuleBuilder()
with mb.capture_function("foobar", [t0, t1]) as f:
    result = t0 + t1
    f.returns([result])

# CHECK-LABEL: func @foobar(
# CHECK-SAME: %[[VAL_0:.*]]: tensor<1x4xf32>,
# CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xf32>) -> tensor<4x4xf32> {
# CHECK: %[[VAL_2:.*]] = torch.constant.int 1
# CHECK: %[[VAL_3:.*]] = "aten.add"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<1x4xf32>, tensor<4x1xf32>, !torch.int) -> tensor<4x4xf32>
# CHECK: return %[[VAL_3]] : tensor<4x4xf32>
# CHECK: }
mb.module.operation.print(large_elements_limit=2)
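The tensor<4x4xf32> result type in the checks comes from ordinary broadcasting of (1,4) against (4,1). A quick eager check, as a sketch:

# Each size-1 dimension expands against the other operand's size-4 dimension.
assert (t0 + t1).shape == (4, 4)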