Work around various PyTorch issues in support of convolution.

* Enables the conv2d fwd test and ResA (which are both small).
* Deletes the resnet18 and vgg tests, which both run but generate output
  that crashes FileCheck and lit (or at least makes them take an eternity).

parent 029815152e
commit 58adb6bd8e
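Below is a minimal sketch of the capture flow this commit exercises, distilled
from the conv2d_fwd test updated later in this diff (the module and shapes are
taken from that test; any other small module works the same way):

    import torch
    import torch.nn as nn
    import torch_mlir

    # A 16->4 channel 3x3 convolution: a (3,16,10,10) input yields (3,4,8,8).
    model = nn.Conv2d(16, 4, 3)
    tensor = torch.randn(3, 16, 10, 10)

    mb = torch_mlir.ModuleBuilder()
    # Entering the capture context installs the ACAP dispatch keys; every
    # kernel invoked while tracing is recorded as a torch.kernel_call op.
    with mb.capture_function("conv2d_fwd", [tensor]) as f:
        f.returns([model(tensor)])

    # The captured module now contains torch.kernel_call "aten::convolution".
    print(mb.module)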
@@ -23,6 +23,7 @@ using namespace torch_mlir;
 namespace py = pybind11;
 
 using c10::FunctionSchema;
+using c10::IValue;
 using c10::OperatorHandle;
 using c10::Stack;
 
@@ -34,9 +35,67 @@ using c10::Stack;
 // TODO: Ask the PT devs why conv is special and only shows up if dispatching
 // through the autograd keys.
 // https://github.com/llvm/mlir-npcomp/issues/86
-// #define ACAP_DISPATCH_KEY AutogradPrivateUse3
-#define ACAP_DISPATCH_KEY PrivateUse3
+#define ACAP_DISPATCH_KEY PrivateUse2
+#define ACAP_GRAD_DISPATCH_KEY AutogradPrivateUse2
 static c10::DispatchKey kAcapDispatchKey = c10::DispatchKey::ACAP_DISPATCH_KEY;
+static c10::DispatchKey kAcapGradDispatchKey =
+    c10::DispatchKey::ACAP_GRAD_DISPATCH_KEY;
 
+AcapController::KernelCallBuilder::KernelCallBuilder(AcapController &parent,
+                                                     MlirContext context,
+                                                     MlirLocation loc,
+                                                     std::string &kernelName)
+    : parent(parent), context(context), loc(loc), kernelName(kernelName),
+      state("torch.kernel_call", loc) {
+  (void)this->context; // Preserve for future.
+  MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet(
+      "kernel_name",
+      mlirStringAttrGet(context, kernelName.size(), kernelName.data()));
+  mlirOperationStateAddAttributes(state, 1, &kernelNameAttr);
+}
+
+void AcapController::KernelCallBuilder::addOperand(const IValue &value) {
+  MlirValue mlirValue = parent.mapIValueToMlirValue(loc, value);
+  if (mlirValueIsNull(mlirValue)) {
+    std::stringstream out;
+    out << "Unsupported capture value returned from kernel '" << kernelName
+        << "' (" << value.tagKind() << "): " << value;
+    throw std::invalid_argument(out.str());
+  }
+  mlirOperationStateAddOperands(state, 1, &mlirValue);
+}
+
+void AcapController::KernelCallBuilder::addResult(const IValue &value) {
+  MlirType resultType = parent.mapIValueToMlirType(loc, value);
+  if (mlirTypeIsNull(resultType)) {
+    std::stringstream out;
+    out << "Unsupported capture value returned from kernel '" << kernelName
+        << "' (" << value.tagKind() << "): " << value;
+    throw std::invalid_argument(out.str());
+  }
+  if (value.isTensor()) {
+    resultIndexToTensorMap.emplace_back(resultCount++, value.toTensor());
+  }
+  mlirOperationStateAddResults(state, 1, &resultType);
+}
+
+MlirOperation AcapController::KernelCallBuilder::create() {
+  // Create operation.
+  MlirOperation op = state.createOperation();
+  parent.funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(op);
+
+  // Map result tensors.
+  for (auto &it : resultIndexToTensorMap) {
+    MlirValue result = mlirOperationGetResult(op, it.first);
+    parent.funcBuilder->mapTensor(it.second, result);
+  }
+
+  // Add to debug log.
+  std::stringstream sout;
+  sout << "CAPTURE: " << kernelName << "\n";
+  parent.captureLog.push_back(sout.str());
+  return op;
+}
+
 std::list<AcapController::Activation> &
 AcapController::getThreadLocalActiveStack() {
@@ -48,8 +107,9 @@ py::object AcapController::contextEnter() {
   auto &stack = getThreadLocalActiveStack();
   stack.emplace_front(shared_from_this());
   Activation &current = stack.front();
-  current.dispatchGuard =
-      std::make_unique<c10::impl::IncludeDispatchKeyGuard>(kAcapDispatchKey);
+  c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
+  current.includeGuard =
+      std::make_unique<c10::impl::IncludeDispatchKeyGuard>(keySet);
   return py::cast(this);
 }
 
@@ -102,7 +162,8 @@ std::vector<std::string> AcapController::getDebugLog() {
   return copy;
 }
 
-std::shared_ptr<AcapController> AcapController::getCurrent() {
+std::shared_ptr<AcapController>
+AcapController::getCurrentThreadAcapController() {
   auto &stack = getThreadLocalActiveStack();
   if (stack.empty())
     return nullptr;
@@ -119,7 +180,7 @@ void AcapController::verifyHasNotReturned() {
 /* static */
 void AcapController::fallbackKernel(const OperatorHandle &opHandle,
                                     Stack *stack) {
-  auto current = getCurrent();
+  auto current = getCurrentThreadAcapController();
   if (!current) {
     current->redispatch(opHandle, stack);
     return;
@@ -127,6 +188,66 @@ void AcapController::fallbackKernel(const OperatorHandle &opHandle,
   current->fallbackKernelImpl(opHandle, stack);
 }
 
+at::Tensor AcapController::convolutionKernel(
+    const at::Tensor &input, const at::Tensor &weight,
+    const c10::optional<at::Tensor> &bias, const at::IntArrayRef stride,
+    const at::IntArrayRef padding, const at::IntArrayRef dilation,
+    const bool transposed, const at::IntArrayRef output_padding,
+    const int64_t groups) {
+  static c10::OperatorName opName{"aten::convolution", ""};
+  auto &dispatcher = c10::Dispatcher::singleton();
+  auto opHandle = dispatcher.findOp(opName);
+  assert(opHandle && "could not find convolution op");
+  auto opTyped = opHandle->typed<at::Tensor(
+      const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &,
+      const at::IntArrayRef, const at::IntArrayRef, const at::IntArrayRef,
+      const bool, const at::IntArrayRef, const int64_t)>();
+
+  // Exclude recursive calls: convolution is completely emitted by this
+  // kernel.
+  c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
+  c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);
+
+  auto current = getCurrentThreadAcapController();
+  if (!current) {
+    return opTyped.callWithDispatchKey(c10::DispatchKey::AutogradOther, input,
+                                       weight, bias, stride, padding, dilation,
+                                       transposed, output_padding, groups);
+  }
+
+  MlirContext context = current->funcBuilder->getContext();
+  MlirLocation loc = current->getCurrentLocation();
+  std::string kernelName{"aten::convolution"};
+  KernelCallBuilder callBuilder{*current, context, loc, kernelName};
+
+  callBuilder.addOperand(IValue(input));
+  callBuilder.addOperand(IValue(weight));
+  // This is really sad: instead of storing a none in the optional, it stores
+  // an undefined tensor, which cannot convert to an IValue :(
+  // TODO: File PyTorch bug. Perhaps this is why they don't support boxing
+  // for it.
+  IValue biasIValue;
+  if (bias && bias->defined()) {
+    biasIValue = IValue(bias);
+  } else {
+    biasIValue = IValue(c10::optional<at::Tensor>());
+  }
+  callBuilder.addOperand(biasIValue);
+  callBuilder.addOperand(IValue(stride));
+  callBuilder.addOperand(IValue(padding));
+  callBuilder.addOperand(IValue(dilation));
+  callBuilder.addOperand(IValue(transposed));
+  callBuilder.addOperand(IValue(output_padding));
+  callBuilder.addOperand(IValue(groups));
+
+  auto result = opTyped.callWithDispatchKey(
+      c10::DispatchKey::AutogradOther, input, weight, bias, stride, padding,
+      dilation, transposed, output_padding, groups);
+  callBuilder.addResult(result);
+  callBuilder.create();
+  return result;
+}
+
 MlirLocation AcapController::getCurrentLocation() {
   return mlirLocationUnknownGet(funcBuilder->getContext());
 }
@@ -154,35 +275,19 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
         "Cannot capture ops with variable arguments or returns");
   }
 
-  // TODO: Extract actual location from stack.
   MlirContext context = funcBuilder->getContext();
-  MlirLocation loc = mlirLocationUnknownGet(context);
-  OperationStateHolder stateHolder("torch.kernel_call", loc);
-
-  // Add the kernel_name attribute.
+  MlirLocation loc = getCurrentLocation();
   auto kernelName = schema.name();
-  MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet(
-      "kernel_name",
-      mlirStringAttrGet(context, kernelName.size(), kernelName.data()));
-  mlirOperationStateAddAttributes(stateHolder, 1, &kernelNameAttr);
+  KernelCallBuilder callBuilder{*this, context, loc, kernelName};
 
   // Map arguments to operands.
   // This must be accumulated into the OperationState prior to re-dispatch
   // since the stack is modified at that point.
   size_t argCount = schema.arguments().size();
   assert(stack->size() >= argCount && "stack too short");
-  llvm::SmallVector<MlirValue, 4> operands;
   for (auto argIt = stack->end() - argCount; argIt != stack->end(); ++argIt) {
-    MlirValue mlirValue = mapIValueToMlirValue(loc, *argIt);
-    if (mlirValueIsNull(mlirValue)) {
-      std::stringstream out;
-      out << "Unsupported capture value returned from kernel '" << kernelName
-          << "' (" << argIt->tagKind() << "): " << *argIt;
-      throw std::invalid_argument(out.str());
-    }
-    operands.push_back(mlirValue);
+    callBuilder.addOperand(*argIt);
   }
-  mlirOperationStateAddOperands(stateHolder, operands.size(), operands.data());
 
   // Invoke the original kernel.
   redispatch(opHandle, stack);
@@ -190,44 +295,16 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
   // Map returns to results.
   size_t returnCount = schema.returns().size();
   assert(stack->size() >= returnCount && "stack too short");
-  llvm::SmallVector<MlirType, 4> resultTypes;
-  llvm::SmallVector<std::pair<size_t, at::Tensor>, 4> resultIndexToTensorMap;
   for (auto returnIt = stack->end() - returnCount; returnIt != stack->end();
        ++returnIt) {
-    size_t resultIndex = resultTypes.size();
-    MlirType resultType = mapIValueToMlirType(loc, *returnIt);
-    if (mlirTypeIsNull(resultType)) {
-      std::stringstream out;
-      out << "Unsupported capture value returned from kernel '" << kernelName
-          << "' (" << returnIt->tagKind() << "): " << *returnIt;
-      throw std::invalid_argument(out.str());
-    }
-    resultTypes.push_back(resultType);
-    if (returnIt->isTensor()) {
-      resultIndexToTensorMap.emplace_back(resultIndex, returnIt->toTensor());
-    }
+    callBuilder.addResult(*returnIt);
   }
-  mlirOperationStateAddResults(stateHolder, resultTypes.size(),
-                               resultTypes.data());
-
-  // Create operation.
-  MlirOperation op = stateHolder.createOperation();
-  funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(op);
-
-  // Map result tensors.
-  for (auto &it : resultIndexToTensorMap) {
-    MlirValue result = mlirOperationGetResult(op, it.first);
-    funcBuilder->mapTensor(it.second, result);
-  }
-
-  // Add to debug log.
-  std::stringstream sout;
-  sout << "CAPTURE: " << opHandle.schema() << "\n";
-  captureLog.push_back(sout.str());
+  callBuilder.create();
 }
 
 MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
-                                               c10::IValue &ival) {
+                                               const IValue &ival) {
   if (ival.isScalar()) {
     return funcBuilder->getScalarConstant(loc, ival.toScalar());
   }
@@ -249,7 +326,7 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
   if (ival.isList()) {
     auto list = ival.toList();
     llvm::SmallVector<MlirValue, 4> elements;
-    for (c10::IValue element : list) {
+    for (IValue element : list) {
       elements.push_back(mapIValueToMlirValue(loc, element));
     }
     return funcBuilder->buildConstantList(loc, elements);
@@ -278,7 +355,7 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
 }
 
 MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
-                                             c10::IValue &ival) {
+                                             const IValue &ival) {
   if (ival.isScalar()) {
     return typeMapper.mapScalarType(ival.toScalar().type());
   }
@@ -376,7 +453,24 @@ TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) {
                          &AcapController::fallbackKernel>());
 }
 
-TORCH_LIBRARY_IMPL(aten, ACAP_DISPATCH_KEY, m) {
-  m.impl("conv2d", torch::CppFunction::makeFromBoxedFunction<
-                       &AcapController::fallbackKernel>());
+TORCH_LIBRARY_IMPL(aten, ACAP_GRAD_DISPATCH_KEY, m) {
+  // The at::convolution op is special in several ways. First, it presently
+  // does not support boxing, so all of the usual fanciness does not apply
+  // and it cannot be intercepted by generic fallthroughs, which is what
+  // would usually allow us to avoid intercepting it at the gradient phase.
+  // Second, the default implementation (see
+  // aten/src/ATen/native/Convolution.cpp) is very switchy based on hard-coded
+  // assumptions about device type. If we do nothing here, we will at best
+  // intercept an mkldnn_convolution, cudnn_convolution, etc on the backend
+  // dispatch keys. Non standard backends that don't have these switches
+  // just route to aten::convolution_overrideable (see the else in
+  // aten::convolution) as a convenience, but that is mostly a pass-through
+  // (except for 3d convolutions which contain a trailing squeeze that needs
+  // special casing). Therefore, we just intercept the aten::convolution op,
+  // record it specially, and then mask ourselves off and ask the CPU backend
+  // to invoke it. Not awesome.
+  // Presumably this is on someone's list to adapt to the dispatch machinery
+  // in a more appropriate way, but as the core of what the framework is,
+  // perhaps people are reticent to touch it. Maybe someday, this can go away.
+  m.impl_UNBOXED("convolution", &AcapController::convolutionKernel);
 }

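For orientation, the following condensed sketch shows the interception pattern
the implementation above relies on; it is illustrative rather than the file's
literal registration code, and captureFallback is a hypothetical stand-in. A
boxed fallback registered on the capture dispatch key sees every boxable op
and masks the key off before re-dispatching, the same ExcludeDispatchKeyGuard
trick convolutionKernel uses to avoid recursion:

    #include <torch/library.h>

    namespace {
    // Boxed fallback: record the op, then forward to the next backend.
    void captureFallback(const c10::OperatorHandle &opHandle,
                         c10::Stack *stack) {
      // ... record opHandle.schema() and the operands on *stack here ...
      // Mask the capture key so the re-dispatch does not loop back to us.
      c10::impl::ExcludeDispatchKeyGuard guard(c10::DispatchKey::PrivateUse2);
      c10::Dispatcher::singleton().callBoxed(opHandle, stack);
    }
    } // namespace

    // "_" registers the fallback for all ops on the given dispatch key.
    TORCH_LIBRARY_IMPL(_, PrivateUse2, m) {
      m.fallback(torch::CppFunction::makeFromBoxedFunction<&captureFallback>());
    }
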
@@ -48,20 +48,47 @@ public:
   std::vector<std::string> getDebugLog();
 
   // Returns the current AcapController (if it has been activated on this
-  // thread. Returns nullptr if none.
-  static std::shared_ptr<AcapController> getCurrent();
+  // thread. Returns nullptr if none (not active on the current thread).
+  static std::shared_ptr<AcapController> getCurrentThreadAcapController();
 
   // The fallback boxed kernel that we route captured dispatches through.
   static void fallbackKernel(const c10::OperatorHandle &opHandle,
                              c10::Stack *stack);
 
+  // Kernel implementation for the boxing-incompatible convolution kernel.
+  static at::Tensor
+  convolutionKernel(const at::Tensor &input, const at::Tensor &weight,
+                    const c10::optional<at::Tensor> &bias,
+                    const at::IntArrayRef stride, const at::IntArrayRef padding,
+                    const at::IntArrayRef dilation, const bool transposed,
+                    const at::IntArrayRef output_padding, const int64_t groups);
+
 private:
+  /// Builds a kernel call step by step.
+  class KernelCallBuilder {
+  public:
+    KernelCallBuilder(AcapController &parent, MlirContext context,
+                      MlirLocation loc, std::string &kernelName);
+    void addOperand(const c10::IValue &value);
+    void addResult(const c10::IValue &result);
+    MlirOperation create();
+
+  private:
+    AcapController &parent;
+    MlirContext context;
+    MlirLocation loc;
+    std::string &kernelName;
+    OperationStateHolder state;
+    int resultCount = 0;
+    llvm::SmallVector<std::pair<size_t, at::Tensor>, 4> resultIndexToTensorMap;
+  };
+
   MlirLocation getCurrentLocation();
   void redispatch(const c10::OperatorHandle &opHandle, c10::Stack *stack);
   void fallbackKernelImpl(const c10::OperatorHandle &opHandle,
                           c10::Stack *stack);
-  MlirValue mapIValueToMlirValue(MlirLocation loc, c10::IValue &ival);
-  MlirType mapIValueToMlirType(MlirLocation loc, c10::IValue &ival);
+  MlirValue mapIValueToMlirValue(MlirLocation loc, const c10::IValue &ival);
+  MlirType mapIValueToMlirType(MlirLocation loc, const c10::IValue &ival);
   /// Imports a tensor by value (as a constant), remembering the association.
   MlirValue importTensorByValue(at::Tensor tensor);
   void verifyHasNotReturned();
@@ -72,7 +99,8 @@ private:
     // The RAII dispatch key guard is not movable, so heap allocate it. This is
     // a bit outside of its intended design, but since this is thread local as
    // well, it should be fine.
-    std::unique_ptr<c10::impl::IncludeDispatchKeyGuard> dispatchGuard;
+    std::unique_ptr<c10::impl::IncludeDispatchKeyGuard> includeGuard;
+    std::unique_ptr<c10::impl::ExcludeDispatchKeyGuard> excludeGuard;
   };
   // Gets the thread local stack of active acap controllers.
   static std::list<Activation> &getThreadLocalActiveStack();

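The header makes the KernelCallBuilder protocol explicit. A condensed usage
fragment, mirroring the new fallbackKernelImpl above (surrounding declarations
and error handling elided, so this is a sketch rather than a compilable unit):

    // 1. Construct with the kernel name; this seeds the torch.kernel_call
    //    OperationState with its kernel_name attribute.
    KernelCallBuilder callBuilder{*this, context, loc, kernelName};
    // 2. Add one operand per argument IValue, before re-dispatch, since the
    //    dispatcher consumes the stack.
    for (auto argIt = stack->end() - argCount; argIt != stack->end(); ++argIt)
      callBuilder.addOperand(*argIt);
    // 3. Run the real kernel; the stack now holds the returns.
    redispatch(opHandle, stack);
    // 4. Add one result per return IValue; then create() inserts the op and
    //    maps any returned tensors to their MLIR values.
    for (auto retIt = stack->end() - returnCount; retIt != stack->end(); ++retIt)
      callBuilder.addResult(*retIt);
    callBuilder.create();
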
@@ -11,8 +11,6 @@ import torch.nn.functional as F
 
 import torch_mlir
 
-# XFAIL: *
-# TODO: https://github.com/llvm/mlir-npcomp/issues/86
 # RUN: %PYTHON %s | npcomp-opt | FileCheck %s
 
 class ResA(nn.Module):
@@ -42,18 +40,13 @@ inputs = torch.ones((1,16,128,128))
 with mb.capture_function("resa", [inputs]) as f:
   f.returns([model(inputs)])
 
-# CHECK-LABEL: func @resa
-# TODO: Update checks when test passes to this point.
-# CHECK: [[V0:%[a-zA-Z0-9]+]], %{{.*}}, %{{.*}} = "aten.native_batch_norm"({{.*}}) {layer_name = "L0-native_batch_norm-0"}
-# CHECK: [[V1:%[a-zA-Z0-9]+]] = "aten.relu"([[V0]]) {layer_name = "L1-relu-0"}
-# CHECK: [[V2:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V1]], {{.*}}) {layer_name = "L2-convolution_overrideable-0"}
-# CHECK: [[V3:%[a-zA-Z0-9_]+]], %{{.*}}, %{{.*}} = "aten.native_batch_norm"([[V2]]{{.*}}) {layer_name = "L3-native_batch_norm-1"}
-# CHECK: [[V4:%[a-zA-Z0-9]+]] = "aten.relu"([[V3]]) {layer_name = "L4-relu-1"}
-# CHECK: [[V5:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V4]],{{.*}}) {layer_name = "L5-convolution_overrideable-1"}
-# CHECK: [[V6:%[a-zA-Z0-9_]+]], %{{.*}}, %{{.*}} = "aten.native_batch_norm"([[V5]],{{.*}}) {layer_name = "L6-native_batch_norm-2"}
-# CHECK: [[V7:%[a-zA-Z0-9]+]] = "aten.relu"([[V6]]) {layer_name = "L7-relu-2"}
-# CHECK: [[V8:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V7]],{{.*}}) {layer_name = "L8-convolution_overrideable-2"}
-# CHECK: {{.*}} = "aten.add"(%arg0, [[V8]], {{.*}}) {layer_name = "L9-add-0"}
-# TODO: Enable printing once large elements can be elided (crashes lit).
-# https://github.com/llvm/mlir-npcomp/issues/87
-# print(mb.module)
+# TODO: This isn't a great unit test but checking-in as a lead-in for more
+# appropriately factored tests.
+# NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+# CHECK-LABEL: func @resa(
+# CHECK-SAME: %[[VAL_0:.*]]: !numpy.ndarray<[1,16,128,128]:f32>) -> !numpy.ndarray<[1,16,128,128]:f32> {
+# CHECK: %[[VAL_118:.*]] = torch.kernel_call "aten::convolution" {{.*}} : (!numpy.ndarray<[1,8,128,128]:f32>, !numpy.ndarray<[16,8,1,1]:f32>, !numpy.ndarray<[16]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64) -> !numpy.ndarray<[1,16,128,128]:f32>
+# CHECK: %[[VAL_119:.*]] = torch.kernel_call "aten::add" %{{.*}}, %[[VAL_118]], %{{.*}} : (!numpy.ndarray<[1,16,128,128]:f32>, !numpy.ndarray<[1,16,128,128]:f32>, i64) -> !numpy.ndarray<[1,16,128,128]:f32>
+# CHECK: return %[[VAL_119]] : !numpy.ndarray<[1,16,128,128]:f32>
+# CHECK: }
+print(mb.module)

@@ -5,8 +5,6 @@
 import torch
 import torch_mlir
 
-# XFAIL: *
-# TODO: https://github.com/llvm/mlir-npcomp/issues/86
 # RUN: %PYTHON %s | npcomp-opt | FileCheck %s
 
 mb = torch_mlir.ModuleBuilder()
@@ -32,10 +30,29 @@ with mb.capture_function("conv2d_fwd", [tensor]) as f:
   result = model(tensor)
   f.returns([result])
 
-# CHECK-LABEL: func @conv2d_fwd
-# CHECK-SAME: (%arg0: !numpy.ndarray<[3,16,10,10]:f32>) -> !numpy.ndarray<[3,4,8,8]:f32> {
-# CHECK: %[[P1:.*]] = numpy.create_array_from_tensor %cst : (tensor<4x16x3x3xf32>) -> !numpy.ndarray<[4,16,3,3]:f32>
-# CHECK: %[[P2:.*]] = numpy.create_array_from_tensor %cst_0 : (tensor<4xf32>) -> !numpy.ndarray<[4]:f32>
-# CHECK: %[[R:.*]] = torch.kernel_call "aten::conv2d" %arg0, %[[P1]], %[[P2]], %0, %1, %2, %c1_i64_5 : (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32>
-# CHECK: return %[[R]] : !numpy.ndarray<[3,4,8,8]:f32>
+# Generated with mlir/utils/generate-test-checks.py
+# This is very deterministic and a change test is appropriate.
+# CHECK-LABEL: func @conv2d_fwd(
+# CHECK-SAME: %[[VAL_0:.*]]: !numpy.ndarray<[3,16,10,10]:f32>) -> !numpy.ndarray<[3,4,8,8]:f32> {
+# CHECK: %[[VAL_1:.*]] = constant dense<{{.*}}> : tensor<4x16x3x3xf32>
+# CHECK: %[[VAL_2:.*]] = constant dense<{{.*}}> : tensor<4xf32>
+# CHECK: %[[VAL_3:.*]] = constant 1 : i64
+# CHECK: %[[VAL_4:.*]] = constant 1 : i64
+# CHECK: %[[VAL_5:.*]] = basicpy.build_list %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !basicpy.ListType
+# CHECK: %[[VAL_6:.*]] = constant 0 : i64
+# CHECK: %[[VAL_7:.*]] = constant 0 : i64
+# CHECK: %[[VAL_8:.*]] = basicpy.build_list %[[VAL_6]], %[[VAL_7]] : (i64, i64) -> !basicpy.ListType
+# CHECK: %[[VAL_9:.*]] = constant 1 : i64
+# CHECK: %[[VAL_10:.*]] = constant 1 : i64
+# CHECK: %[[VAL_11:.*]] = basicpy.build_list %[[VAL_9]], %[[VAL_10]] : (i64, i64) -> !basicpy.ListType
+# CHECK: %[[VAL_12:.*]] = constant false
+# CHECK: %[[VAL_13:.*]] = constant 0 : i64
+# CHECK: %[[VAL_14:.*]] = constant 0 : i64
+# CHECK: %[[VAL_15:.*]] = basicpy.build_list %[[VAL_13]], %[[VAL_14]] : (i64, i64) -> !basicpy.ListType
+# CHECK: %[[VAL_16:.*]] = constant 1 : i64
+# CHECK: %[[VAL_17:.*]] = numpy.create_array_from_tensor %[[VAL_1]] : (tensor<4x16x3x3xf32>) -> !numpy.ndarray<[4,16,3,3]:f32>
+# CHECK: %[[VAL_18:.*]] = numpy.create_array_from_tensor %[[VAL_2]] : (tensor<4xf32>) -> !numpy.ndarray<[4]:f32>
+# CHECK: %[[VAL_19:.*]] = torch.kernel_call "aten::convolution" %[[VAL_0]], %[[VAL_17]], %[[VAL_18]], %[[VAL_5]], %[[VAL_8]], %[[VAL_11]], %[[VAL_12]], %[[VAL_15]], %[[VAL_16]] : (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32>
+# CHECK: return %[[VAL_19]] : !numpy.ndarray<[3,4,8,8]:f32>
+# CHECK: }
 print(mb.module)

@@ -1,30 +0,0 @@
-# -*- Python -*-
-# This file is licensed under a pytorch-style license
-# See frontends/pytorch/LICENSE for license information.
-
-import torch
-import torch_mlir
-import torchvision.models as models
-
-# XFAIL: *
-# TODO: https://github.com/llvm/mlir-npcomp/issues/86
-# TODO: Pass through npcomp-opt and FileCheck once able to elide large elements.
-# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
-
-model = models.resnet18()
-model.training = False
-
-tensor = torch.randn(32,3,32,32)
-
-mb = torch_mlir.ModuleBuilder()
-
-with mb.capture_function("res18", [tensor]) as f:
-  result = model(tensor)
-  f.returns([result])
-
-# for now we just check the output shape
-# CHECK-LABEL: @res18
-# TODO: Add checks once running to this point.
-# TODO: Enable printing once large elements can be elided (crashes lit).
-# https://github.com/llvm/mlir-npcomp/issues/87
-# print(mb.module)

@@ -1,28 +0,0 @@
-# -*- Python -*-
-# This file is licensed under a pytorch-style license
-# See frontends/pytorch/LICENSE for license information.
-
-import torch
-import torch_mlir
-import torchvision.models as models
-
-# XFAIL: *
-# TODO: https://github.com/llvm/mlir-npcomp/issues/86
-# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
-
-model = models.vgg11_bn()
-model.training = False
-
-inputs = torch.ones(32,3,32,32)
-
-mb = torch_mlir.ModuleBuilder()
-
-with mb.capture_function("vgg11", [inputs]) as f:
-  result = model(inputs)
-  f.returns([result])
-
-# CHECK-LABEL: func @vgg11
-# TODO: Add checks once passing this far.
-# TODO: Enable printing once large elements can be elided (crashes lit).
-# https://github.com/llvm/mlir-npcomp/issues/87
-# print(mb.module)

@@ -25,5 +25,5 @@ with mb.capture_function("foobar", [t0, t1]) as f:
 # CHECK: }
 print(mb.module)
 
-# CHECK: CAPTURE: aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)
+# CHECK: CAPTURE: aten::add
 for line in f.get_debug_log(): print(line)

@@ -98,6 +98,8 @@ def AnyScalar : AnyTypeOf<[
 def AnyTorchType : AnyTypeOf<[
   AnyScalar,
   AnyTorchTensorType,
+  Basicpy_ListType,
+  Basicpy_NoneType,
 ], "Any type that is legal to pass to a Torch kernel">;
 
 #endif // TORCH_BASE