Add Conv2D Torchscript Import Support (#167)

Adds support for lowering a torch.nn.Conv2d module to the Torch Dialect through TorchScript import.
Generated IR can be viewed here:
https://gist.github.com/brycearden/6c0f790115c4577249372ef82768e6fd

This required implementing support for tuples in the IValue importer and for lists in the node importer.
pull/169/head
Bryce Arden 2021-02-25 14:14:00 -06:00 committed by GitHub
parent a375ccf9da
commit 27a4515de2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 79 additions and 0 deletions

View File

@ -253,6 +253,17 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
npcompListTypeGet(context), elems);
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isTuple()) {
auto list = ivalue.toTuple()->elements();
std::vector<MlirValue> elems;
for (const c10::IValue &elem : list) {
elems.push_back(importIValue(elem));
}
MlirOperation operation =
createMlirOperationAtEnd(importBlock, "basicpy.build_tuple", loc,
npcompTupleTypeGet(context), elems);
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isTensor()) {
at::Tensor tensor = ivalue.toTensor().contiguous();
MlirAttribute denseElements = converTensorToMlirElementsAttr(tensor, loc);

View File

@ -123,6 +123,14 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
return;
}
if (kind == c10::prim::ListConstruct) {
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "basicpy.build_list", loc, npcompListTypeGet(context),
lookupMappedValues(node->inputs()));
mapResults(node, operation);
return;
}
if (kind == c10::prim::If) {
// TorchScript will already have an explicit op to determine truthiness. So
// all we need to do here is launder !basicpy.BoolType to i1 for `scf.if`.

View File

@ -0,0 +1,25 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
# Lit test: verifies that importing a TorchScript function that returns a
# Python list produces a `basicpy.build_list` op yielding !basicpy.ListType.
mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @f(
# CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
# CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType {
# CHECK: %[[RET:.*]] = basicpy.build_list %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType
# CHECK: return %[[RET]] : !basicpy.ListType
@mb.import_function
@torch.jit.script
def f(t0, t1):
    # Returning a list exercises the prim::ListConstruct path in the node
    # importer.
    return [t0, t1]
# Sanity check that scripting succeeded before printing the imported module.
assert isinstance(f, torch.jit.ScriptFunction)
mb.module.operation.print()
print()

View File

@ -0,0 +1,33 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import typing
import torch
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
# Lit test: verifies that a module attribute holding a Python tuple imports
# as a `basicpy.build_tuple` op stored in a torch.slot.
mb = torch_mlir.ModuleBuilder()
class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Tuple-valued slot; exercises the isTuple() path in the IValue
        # importer.
        self.t = (1, 2)
# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
# TODO: Don't lose element type.
# CHECK: }
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
# CHECK: %[[N2:.*]] = basicpy.numeric_constant 2 : i64
# CHECK: %[[TUPLE:.*]] = basicpy.build_tuple %[[N1]], %[[N2]] : (i64, i64) -> !basicpy.TupleType
# CHECK: torch.nn_module {
# CHECK: torch.slot "t", %[[TUPLE]] : !basicpy.TupleType
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c)
mb.module.operation.print()

View File

@ -164,9 +164,11 @@ def AnyTorchType : AnyTypeOf<[
AnyTorchScalarType,
AnyTorchTensorType,
Basicpy_ListType,
Basicpy_TupleType,
Basicpy_NoneType,
Basicpy_BytesType,
Torch_NnModuleType,
Torch_OptionalType,
], "Any type that is legal to pass to a Torch kernel">;
#endif // TORCH_TYPES