mirror of https://github.com/llvm/torch-mlir
Add Conv2D Torchscript Import Support (#167)
Adds support for lowering a torch.nn.Conv2d module to the Torch Dialect through TorchScript import. Generated IR can be viewed here: https://gist.github.com/brycearden/6c0f790115c4577249372ef82768e6fd Required implementing support for tuple in the ivalue importer and list in the node importer.pull/169/head
parent
a375ccf9da
commit
27a4515de2
|
@ -253,6 +253,17 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
npcompListTypeGet(context), elems);
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
if (ivalue.isTuple()) {
|
||||
auto list = ivalue.toTuple()->elements();
|
||||
std::vector<MlirValue> elems;
|
||||
for (const c10::IValue &elem : list) {
|
||||
elems.push_back(importIValue(elem));
|
||||
}
|
||||
MlirOperation operation =
|
||||
createMlirOperationAtEnd(importBlock, "basicpy.build_tuple", loc,
|
||||
npcompTupleTypeGet(context), elems);
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
if (ivalue.isTensor()) {
|
||||
at::Tensor tensor = ivalue.toTensor().contiguous();
|
||||
MlirAttribute denseElements = converTensorToMlirElementsAttr(tensor, loc);
|
||||
|
|
|
@ -123,6 +123,14 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
|||
return;
|
||||
}
|
||||
|
||||
if (kind == c10::prim::ListConstruct) {
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "basicpy.build_list", loc, npcompListTypeGet(context),
|
||||
lookupMappedValues(node->inputs()));
|
||||
mapResults(node, operation);
|
||||
return;
|
||||
}
|
||||
|
||||
if (kind == c10::prim::If) {
|
||||
// TorchScript will already have an explicit op to determine truthiness. So
|
||||
// all we need to do here is launder !basicpy.BoolType to i1 for `scf.if`.
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# CHECK-LABEL: func @f(
|
||||
# CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
||||
# CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType {
|
||||
# CHECK: %[[RET:.*]] = basicpy.build_list %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType
|
||||
# CHECK: return %[[RET]] : !basicpy.ListType
|
||||
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def f(t0, t1):
|
||||
return [t0, t1]
|
||||
|
||||
assert isinstance(f, torch.jit.ScriptFunction)
|
||||
mb.module.operation.print()
|
||||
print()
|
|
@ -0,0 +1,33 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import typing
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.t = (1, 2)
|
||||
# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
|
||||
# TODO: Don't lose element type.
|
||||
# CHECK: }
|
||||
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
|
||||
# CHECK: %[[N2:.*]] = basicpy.numeric_constant 2 : i64
|
||||
# CHECK: %[[TUPLE:.*]] = basicpy.build_tuple %[[N1]], %[[N2]] : (i64, i64) -> !basicpy.TupleType
|
||||
# CHECK: torch.nn_module {
|
||||
# CHECK: torch.slot "t", %[[TUPLE]] : !basicpy.TupleType
|
||||
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
|
||||
|
||||
|
||||
test_module = TestModule()
|
||||
recursivescriptmodule = torch.jit.script(test_module)
|
||||
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||
mb.import_module(recursivescriptmodule._c)
|
||||
mb.module.operation.print()
|
|
@ -164,9 +164,11 @@ def AnyTorchType : AnyTypeOf<[
|
|||
AnyTorchScalarType,
|
||||
AnyTorchTensorType,
|
||||
Basicpy_ListType,
|
||||
Basicpy_TupleType,
|
||||
Basicpy_NoneType,
|
||||
Basicpy_BytesType,
|
||||
Torch_NnModuleType,
|
||||
Torch_OptionalType,
|
||||
], "Any type that is legal to pass to a Torch kernel">;
|
||||
|
||||
#endif // TORCH_TYPES
|
||||
|
|
Loading…
Reference in New Issue