Work around various PyTorch issues in support of convolution.

* Enables the conv2d fwd test and ResA (which are both small).
* Deletes resnet18 and vgg, which both run but generate output that crashes FileCheck and lit (or at least makes them take an eternity).
pull/90/head
Stella Laurenzo 2020-10-18 21:32:29 -07:00
parent 029815152e
commit 58adb6bd8e
8 changed files with 226 additions and 150 deletions

View File

@ -23,6 +23,7 @@ using namespace torch_mlir;
namespace py = pybind11;
using c10::FunctionSchema;
using c10::IValue;
using c10::OperatorHandle;
using c10::Stack;
@ -34,9 +35,67 @@ using c10::Stack;
// TODO: Ask the PT devs why conv is special and only shows up if dispatching
// through the autograd keys.
// https://github.com/llvm/mlir-npcomp/issues/86
// #define ACAP_DISPATCH_KEY AutogradPrivateUse3
#define ACAP_DISPATCH_KEY PrivateUse3
#define ACAP_DISPATCH_KEY PrivateUse2
#define ACAP_GRAD_DISPATCH_KEY AutogradPrivateUse2
static c10::DispatchKey kAcapDispatchKey = c10::DispatchKey::ACAP_DISPATCH_KEY;
static c10::DispatchKey kAcapGradDispatchKey =
c10::DispatchKey::ACAP_GRAD_DISPATCH_KEY;
AcapController::KernelCallBuilder::KernelCallBuilder(AcapController &parent,
MlirContext context,
MlirLocation loc,
std::string &kernelName)
: parent(parent), context(context), loc(loc), kernelName(kernelName),
state("torch.kernel_call", loc) {
(void)this->context; // Preserve for future.
MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet(
"kernel_name",
mlirStringAttrGet(context, kernelName.size(), kernelName.data()));
mlirOperationStateAddAttributes(state, 1, &kernelNameAttr);
}
void AcapController::KernelCallBuilder::addOperand(const IValue &value) {
MlirValue mlirValue = parent.mapIValueToMlirValue(loc, value);
if (mlirValueIsNull(mlirValue)) {
std::stringstream out;
out << "Unsupported capture value returned from kernel '" << kernelName
<< "' (" << value.tagKind() << "): " << value;
throw std::invalid_argument(out.str());
}
mlirOperationStateAddOperands(state, 1, &mlirValue);
}
void AcapController::KernelCallBuilder::addResult(const IValue &value) {
MlirType resultType = parent.mapIValueToMlirType(loc, value);
if (mlirTypeIsNull(resultType)) {
std::stringstream out;
out << "Unsupported capture value returned from kernel '" << kernelName
<< "' (" << value.tagKind() << "): " << value;
throw std::invalid_argument(out.str());
}
if (value.isTensor()) {
resultIndexToTensorMap.emplace_back(resultCount++, value.toTensor());
}
mlirOperationStateAddResults(state, 1, &resultType);
}
MlirOperation AcapController::KernelCallBuilder::create() {
// Create operation.
MlirOperation op = state.createOperation();
parent.funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(op);
// Map result tensors.
for (auto &it : resultIndexToTensorMap) {
MlirValue result = mlirOperationGetResult(op, it.first);
parent.funcBuilder->mapTensor(it.second, result);
}
// Add to debug log.
std::stringstream sout;
sout << "CAPTURE: " << kernelName << "\n";
parent.captureLog.push_back(sout.str());
return op;
}
std::list<AcapController::Activation> &
AcapController::getThreadLocalActiveStack() {
@ -48,8 +107,9 @@ py::object AcapController::contextEnter() {
auto &stack = getThreadLocalActiveStack();
stack.emplace_front(shared_from_this());
Activation &current = stack.front();
current.dispatchGuard =
std::make_unique<c10::impl::IncludeDispatchKeyGuard>(kAcapDispatchKey);
c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
current.includeGuard =
std::make_unique<c10::impl::IncludeDispatchKeyGuard>(keySet);
return py::cast(this);
}
@ -102,7 +162,8 @@ std::vector<std::string> AcapController::getDebugLog() {
return copy;
}
std::shared_ptr<AcapController> AcapController::getCurrent() {
std::shared_ptr<AcapController>
AcapController::getCurrentThreadAcapController() {
auto &stack = getThreadLocalActiveStack();
if (stack.empty())
return nullptr;
@ -119,7 +180,7 @@ void AcapController::verifyHasNotReturned() {
/* static */
void AcapController::fallbackKernel(const OperatorHandle &opHandle,
Stack *stack) {
auto current = getCurrent();
auto current = getCurrentThreadAcapController();
if (!current) {
current->redispatch(opHandle, stack);
return;
@ -127,6 +188,66 @@ void AcapController::fallbackKernel(const OperatorHandle &opHandle,
current->fallbackKernelImpl(opHandle, stack);
}
at::Tensor AcapController::convolutionKernel(
const at::Tensor &input, const at::Tensor &weight,
const c10::optional<at::Tensor> &bias, const at::IntArrayRef stride,
const at::IntArrayRef padding, const at::IntArrayRef dilation,
const bool transposed, const at::IntArrayRef output_padding,
const int64_t groups) {
static c10::OperatorName opName{"aten::convolution", ""};
auto &dispatcher = c10::Dispatcher::singleton();
auto opHandle = dispatcher.findOp(opName);
assert(opHandle && "could not find convolution op");
auto opTyped = opHandle->typed<at::Tensor(
const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &,
const at::IntArrayRef, const at::IntArrayRef, const at::IntArrayRef,
const bool, const at::IntArrayRef, const int64_t)>();
// Exclude recursive calls: convolution is completely emitted by this
// kernel.
c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);
auto current = getCurrentThreadAcapController();
if (!current) {
return opTyped.callWithDispatchKey(c10::DispatchKey::AutogradOther, input,
weight, bias, stride, padding, dilation,
transposed, output_padding, groups);
}
MlirContext context = current->funcBuilder->getContext();
MlirLocation loc = current->getCurrentLocation();
std::string kernelName{"aten::convolution"};
KernelCallBuilder callBuilder{*current, context, loc, kernelName};
callBuilder.addOperand(IValue(input));
callBuilder.addOperand(IValue(weight));
// This is really sad: instead of storing a none in the optional, it stores
// an undefined tensor, which cannot convert to an IValue :(
// TODO: File PyTorch bug. Perhaps this is why they don't support boxing
// for it.
IValue biasIValue;
if (bias && bias->defined()) {
biasIValue = IValue(bias);
} else {
biasIValue = IValue(c10::optional<at::Tensor>());
}
callBuilder.addOperand(biasIValue);
callBuilder.addOperand(IValue(stride));
callBuilder.addOperand(IValue(padding));
callBuilder.addOperand(IValue(dilation));
callBuilder.addOperand(IValue(transposed));
callBuilder.addOperand(IValue(output_padding));
callBuilder.addOperand(IValue(groups));
auto result = opTyped.callWithDispatchKey(
c10::DispatchKey::AutogradOther, input, weight, bias, stride, padding,
dilation, transposed, output_padding, groups);
callBuilder.addResult(result);
callBuilder.create();
return result;
}
MlirLocation AcapController::getCurrentLocation() {
return mlirLocationUnknownGet(funcBuilder->getContext());
}
@ -154,35 +275,19 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
"Cannot capture ops with variable arguments or returns");
}
// TODO: Extract actual location from stack.
MlirContext context = funcBuilder->getContext();
MlirLocation loc = mlirLocationUnknownGet(context);
OperationStateHolder stateHolder("torch.kernel_call", loc);
// Add the kernel_name attribute.
MlirLocation loc = getCurrentLocation();
auto kernelName = schema.name();
MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet(
"kernel_name",
mlirStringAttrGet(context, kernelName.size(), kernelName.data()));
mlirOperationStateAddAttributes(stateHolder, 1, &kernelNameAttr);
KernelCallBuilder callBuilder{*this, context, loc, kernelName};
// Map arguments to operands.
// This must be accumulated into the OperationState prior to re-dispatch
// since the stack is modified at that point.
size_t argCount = schema.arguments().size();
assert(stack->size() >= argCount && "stack too short");
llvm::SmallVector<MlirValue, 4> operands;
for (auto argIt = stack->end() - argCount; argIt != stack->end(); ++argIt) {
MlirValue mlirValue = mapIValueToMlirValue(loc, *argIt);
if (mlirValueIsNull(mlirValue)) {
std::stringstream out;
out << "Unsupported capture value returned from kernel '" << kernelName
<< "' (" << argIt->tagKind() << "): " << *argIt;
throw std::invalid_argument(out.str());
callBuilder.addOperand(*argIt);
}
operands.push_back(mlirValue);
}
mlirOperationStateAddOperands(stateHolder, operands.size(), operands.data());
// Invoke the original kernel.
redispatch(opHandle, stack);
@ -190,44 +295,16 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
// Map returns to results.
size_t returnCount = schema.returns().size();
assert(stack->size() >= returnCount && "stack too short");
llvm::SmallVector<MlirType, 4> resultTypes;
llvm::SmallVector<std::pair<size_t, at::Tensor>, 4> resultIndexToTensorMap;
for (auto returnIt = stack->end() - returnCount; returnIt != stack->end();
++returnIt) {
size_t resultIndex = resultTypes.size();
MlirType resultType = mapIValueToMlirType(loc, *returnIt);
if (mlirTypeIsNull(resultType)) {
std::stringstream out;
out << "Unsupported capture value returned from kernel '" << kernelName
<< "' (" << returnIt->tagKind() << "): " << *returnIt;
throw std::invalid_argument(out.str());
}
resultTypes.push_back(resultType);
if (returnIt->isTensor()) {
resultIndexToTensorMap.emplace_back(resultIndex, returnIt->toTensor());
}
}
mlirOperationStateAddResults(stateHolder, resultTypes.size(),
resultTypes.data());
// Create operation.
MlirOperation op = stateHolder.createOperation();
funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(op);
// Map result tensors.
for (auto &it : resultIndexToTensorMap) {
MlirValue result = mlirOperationGetResult(op, it.first);
funcBuilder->mapTensor(it.second, result);
callBuilder.addResult(*returnIt);
}
// Add to debug log.
std::stringstream sout;
sout << "CAPTURE: " << opHandle.schema() << "\n";
captureLog.push_back(sout.str());
callBuilder.create();
}
MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
c10::IValue &ival) {
const IValue &ival) {
if (ival.isScalar()) {
return funcBuilder->getScalarConstant(loc, ival.toScalar());
}
@ -249,7 +326,7 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
if (ival.isList()) {
auto list = ival.toList();
llvm::SmallVector<MlirValue, 4> elements;
for (c10::IValue element : list) {
for (IValue element : list) {
elements.push_back(mapIValueToMlirValue(loc, element));
}
return funcBuilder->buildConstantList(loc, elements);
@ -278,7 +355,7 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
}
MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
c10::IValue &ival) {
const IValue &ival) {
if (ival.isScalar()) {
return typeMapper.mapScalarType(ival.toScalar().type());
}
@ -376,7 +453,24 @@ TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) {
&AcapController::fallbackKernel>());
}
TORCH_LIBRARY_IMPL(aten, ACAP_DISPATCH_KEY, m) {
m.impl("conv2d", torch::CppFunction::makeFromBoxedFunction<
&AcapController::fallbackKernel>());
TORCH_LIBRARY_IMPL(aten, ACAP_GRAD_DISPATCH_KEY, m) {
// The at::convolution op is special in several ways. First, it presently
// does not support boxing, so all of the usual fanciness does not apply
// and it cannot be intercepted by generic fallthroughs, which is what
// would usually allow us to avoid intercepting it at the gradient phase.
// Second, the default implementation (see
// aten/src/ATen/native/Convolution.cpp) is very switchy based on hard-coded
// assumptions about device type. If we do nothing here, we will at best
// intercept an mkldnn_convolution, cudnn_convolution, etc on the backend
// dispatch keys. Non standard backends that don't have these switches
// just route to aten::convolution_overrideable (see the else in
// aten::convolution) as a convenience, but that is mostly a pass-through
// (except for 3d convolutions which contain a trailing squeeze that needs
// special casing). Therefore, we just intercept the aten::convolution op,
// record it specially, and then mask ourselves off and ask the CPU backend
// to invoke it. Not awesome.
// Presumably this is on someone's list to adapt to the dispatch machinery
// in a more appropriate way, but as the core of what the framework is,
// perhaps people are reticent to touch it. Maybe someday, this can go away.
m.impl_UNBOXED("convolution", &AcapController::convolutionKernel);
}

View File

@ -48,20 +48,47 @@ public:
std::vector<std::string> getDebugLog();
// Returns the current AcapController (if it has been activated on this
// thread. Returns nullptr if none.
static std::shared_ptr<AcapController> getCurrent();
// thread. Returns nullptr if none (not active on the current thread).
static std::shared_ptr<AcapController> getCurrentThreadAcapController();
// The fallback boxed kernel that we route captured dispatches through.
static void fallbackKernel(const c10::OperatorHandle &opHandle,
c10::Stack *stack);
// Kernel implementation for the boxing-incompatible convolution kernel.
static at::Tensor
convolutionKernel(const at::Tensor &input, const at::Tensor &weight,
const c10::optional<at::Tensor> &bias,
const at::IntArrayRef stride, const at::IntArrayRef padding,
const at::IntArrayRef dilation, const bool transposed,
const at::IntArrayRef output_padding, const int64_t groups);
private:
/// Builds a kernel call step by step.
class KernelCallBuilder {
public:
KernelCallBuilder(AcapController &parent, MlirContext context,
MlirLocation loc, std::string &kernelName);
void addOperand(const c10::IValue &value);
void addResult(const c10::IValue &result);
MlirOperation create();
private:
AcapController &parent;
MlirContext context;
MlirLocation loc;
std::string &kernelName;
OperationStateHolder state;
int resultCount = 0;
llvm::SmallVector<std::pair<size_t, at::Tensor>, 4> resultIndexToTensorMap;
};
MlirLocation getCurrentLocation();
void redispatch(const c10::OperatorHandle &opHandle, c10::Stack *stack);
void fallbackKernelImpl(const c10::OperatorHandle &opHandle,
c10::Stack *stack);
MlirValue mapIValueToMlirValue(MlirLocation loc, c10::IValue &ival);
MlirType mapIValueToMlirType(MlirLocation loc, c10::IValue &ival);
MlirValue mapIValueToMlirValue(MlirLocation loc, const c10::IValue &ival);
MlirType mapIValueToMlirType(MlirLocation loc, const c10::IValue &ival);
/// Imports a tensor by value (as a constant), remembering the association.
MlirValue importTensorByValue(at::Tensor tensor);
void verifyHasNotReturned();
@ -72,7 +99,8 @@ private:
// The RAII dispatch key guard is not movable, so heap allocate it. This is
// a bit outside of its intended design, but since this is thread local as
// well, it should be fine.
std::unique_ptr<c10::impl::IncludeDispatchKeyGuard> dispatchGuard;
std::unique_ptr<c10::impl::IncludeDispatchKeyGuard> includeGuard;
std::unique_ptr<c10::impl::ExcludeDispatchKeyGuard> excludeGuard;
};
// Gets the thread local stack of active acap controllers.
static std::list<Activation> &getThreadLocalActiveStack();

View File

@ -11,8 +11,6 @@ import torch.nn.functional as F
import torch_mlir
# XFAIL: *
# TODO: https://github.com/llvm/mlir-npcomp/issues/86
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
class ResA(nn.Module):
@ -42,18 +40,13 @@ inputs = torch.ones((1,16,128,128))
with mb.capture_function("resa", [inputs]) as f:
f.returns([model(inputs)])
# CHECK-LABEL: func @resa
# TODO: Update checks when test passes to this point.
# CHECK: [[V0:%[a-zA-Z0-9]+]], %{{.*}}, %{{.*}} = "aten.native_batch_norm"({{.*}}) {layer_name = "L0-native_batch_norm-0"}
# CHECK: [[V1:%[a-zA-Z0-9]+]] = "aten.relu"([[V0]]) {layer_name = "L1-relu-0"}
# CHECK: [[V2:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V1]], {{.*}}) {layer_name = "L2-convolution_overrideable-0"}
# CHECK: [[V3:%[a-zA-Z0-9_]+]], %{{.*}}, %{{.*}} = "aten.native_batch_norm"([[V2]]{{.*}}) {layer_name = "L3-native_batch_norm-1"}
# CHECK: [[V4:%[a-zA-Z0-9]+]] = "aten.relu"([[V3]]) {layer_name = "L4-relu-1"}
# CHECK: [[V5:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V4]],{{.*}}) {layer_name = "L5-convolution_overrideable-1"}
# CHECK: [[V6:%[a-zA-Z0-9_]+]], %{{.*}}, %{{.*}} = "aten.native_batch_norm"([[V5]],{{.*}}) {layer_name = "L6-native_batch_norm-2"}
# CHECK: [[V7:%[a-zA-Z0-9]+]] = "aten.relu"([[V6]]) {layer_name = "L7-relu-2"}
# CHECK: [[V8:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V7]],{{.*}}) {layer_name = "L8-convolution_overrideable-2"}
# CHECK: {{.*}} = "aten.add"(%arg0, [[V8]], {{.*}}) {layer_name = "L9-add-0"}
# TODO: Enable printing once large elements can be elided (crashes lit).
# https://github.com/llvm/mlir-npcomp/issues/87
# print(mb.module)
# TODO: This isn't a great unit test but checking-in as a lead-in for more
# appropriately factored tests.
# NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
# CHECK-LABEL: func @resa(
# CHECK-SAME: %[[VAL_0:.*]]: !numpy.ndarray<[1,16,128,128]:f32>) -> !numpy.ndarray<[1,16,128,128]:f32> {
# CHECK: %[[VAL_118:.*]] = torch.kernel_call "aten::convolution" {{.*}} : (!numpy.ndarray<[1,8,128,128]:f32>, !numpy.ndarray<[16,8,1,1]:f32>, !numpy.ndarray<[16]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64) -> !numpy.ndarray<[1,16,128,128]:f32>
# CHECK: %[[VAL_119:.*]] = torch.kernel_call "aten::add" %{{.*}}, %[[VAL_118]], %{{.*}} : (!numpy.ndarray<[1,16,128,128]:f32>, !numpy.ndarray<[1,16,128,128]:f32>, i64) -> !numpy.ndarray<[1,16,128,128]:f32>
# CHECK: return %[[VAL_119]] : !numpy.ndarray<[1,16,128,128]:f32>
# CHECK: }
print(mb.module)

View File

@ -5,8 +5,6 @@
import torch
import torch_mlir
# XFAIL: *
# TODO: https://github.com/llvm/mlir-npcomp/issues/86
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
@ -32,10 +30,29 @@ with mb.capture_function("conv2d_fwd", [tensor]) as f:
result = model(tensor)
f.returns([result])
# CHECK-LABEL: func @conv2d_fwd
# CHECK-SAME: (%arg0: !numpy.ndarray<[3,16,10,10]:f32>) -> !numpy.ndarray<[3,4,8,8]:f32> {
# CHECK: %[[P1:.*]] = numpy.create_array_from_tensor %cst : (tensor<4x16x3x3xf32>) -> !numpy.ndarray<[4,16,3,3]:f32>
# CHECK: %[[P2:.*]] = numpy.create_array_from_tensor %cst_0 : (tensor<4xf32>) -> !numpy.ndarray<[4]:f32>
# CHECK: %[[R:.*]] = torch.kernel_call "aten::conv2d" %arg0, %[[P1]], %[[P2]], %0, %1, %2, %c1_i64_5 : (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32>
# CHECK: return %[[R]] : !numpy.ndarray<[3,4,8,8]:f32>
# Generated with mlir/utils/generate-test-checks.py
# This is very deterministic and a change test is appropriate.
# CHECK-LABEL: func @conv2d_fwd(
# CHECK-SAME: %[[VAL_0:.*]]: !numpy.ndarray<[3,16,10,10]:f32>) -> !numpy.ndarray<[3,4,8,8]:f32> {
# CHECK: %[[VAL_1:.*]] = constant dense<{{.*}}> : tensor<4x16x3x3xf32>
# CHECK: %[[VAL_2:.*]] = constant dense<{{.*}}> : tensor<4xf32>
# CHECK: %[[VAL_3:.*]] = constant 1 : i64
# CHECK: %[[VAL_4:.*]] = constant 1 : i64
# CHECK: %[[VAL_5:.*]] = basicpy.build_list %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !basicpy.ListType
# CHECK: %[[VAL_6:.*]] = constant 0 : i64
# CHECK: %[[VAL_7:.*]] = constant 0 : i64
# CHECK: %[[VAL_8:.*]] = basicpy.build_list %[[VAL_6]], %[[VAL_7]] : (i64, i64) -> !basicpy.ListType
# CHECK: %[[VAL_9:.*]] = constant 1 : i64
# CHECK: %[[VAL_10:.*]] = constant 1 : i64
# CHECK: %[[VAL_11:.*]] = basicpy.build_list %[[VAL_9]], %[[VAL_10]] : (i64, i64) -> !basicpy.ListType
# CHECK: %[[VAL_12:.*]] = constant false
# CHECK: %[[VAL_13:.*]] = constant 0 : i64
# CHECK: %[[VAL_14:.*]] = constant 0 : i64
# CHECK: %[[VAL_15:.*]] = basicpy.build_list %[[VAL_13]], %[[VAL_14]] : (i64, i64) -> !basicpy.ListType
# CHECK: %[[VAL_16:.*]] = constant 1 : i64
# CHECK: %[[VAL_17:.*]] = numpy.create_array_from_tensor %[[VAL_1]] : (tensor<4x16x3x3xf32>) -> !numpy.ndarray<[4,16,3,3]:f32>
# CHECK: %[[VAL_18:.*]] = numpy.create_array_from_tensor %[[VAL_2]] : (tensor<4xf32>) -> !numpy.ndarray<[4]:f32>
# CHECK: %[[VAL_19:.*]] = torch.kernel_call "aten::convolution" %[[VAL_0]], %[[VAL_17]], %[[VAL_18]], %[[VAL_5]], %[[VAL_8]], %[[VAL_11]], %[[VAL_12]], %[[VAL_15]], %[[VAL_16]] : (!numpy.ndarray<[3,16,10,10]:f32>, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i1, !basicpy.ListType, i64) -> !numpy.ndarray<[3,4,8,8]:f32>
# CHECK: return %[[VAL_19]] : !numpy.ndarray<[3,4,8,8]:f32>
# CHECK: }
print(mb.module)

View File

@ -1,30 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
import torchvision.models as models
# XFAIL: *
# TODO: https://github.com/llvm/mlir-npcomp/issues/86
# TODO: Pass through npcomp-opt and FileCheck once able to elide large elements.
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
model = models.resnet18()
model.training = False
tensor = torch.randn(32,3,32,32)
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("res18", [tensor]) as f:
result = model(tensor)
f.returns([result])
# for now we just check the output shape
# CHECK-LABEL: @res18
# TODO: Add checks once running to this point.
# TODO: Enable printing once large elements can be elided (crashes lit).
# https://github.com/llvm/mlir-npcomp/issues/87
# print(mb.module)

View File

@ -1,28 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
import torchvision.models as models
# XFAIL: *
# TODO: https://github.com/llvm/mlir-npcomp/issues/86
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
model = models.vgg11_bn()
model.training = False
inputs = torch.ones(32,3,32,32)
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("vgg11", [inputs]) as f:
result = model(inputs)
f.returns([result])
# CHECK-LABEL: func @vgg11
# TODO: Add checks once passing this far.
# TODO: Enable printing once large elements can be elided (crashes lit).
# https://github.com/llvm/mlir-npcomp/issues/87
# print(mb.module)

View File

@ -25,5 +25,5 @@ with mb.capture_function("foobar", [t0, t1]) as f:
# CHECK: }
print(mb.module)
# CHECK: CAPTURE: aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)
# CHECK: CAPTURE: aten::add
for line in f.get_debug_log(): print(line)

View File

@ -98,6 +98,8 @@ def AnyScalar : AnyTypeOf<[
def AnyTorchType : AnyTypeOf<[
AnyScalar,
AnyTorchTensorType,
Basicpy_ListType,
Basicpy_NoneType,
], "Any type that is legal to pass to a Torch kernel">;
#endif // TORCH_BASE