Add prim::device and handle derefining for prim::CallMethod

pull/184/head
Sean Silva 2021-03-10 17:25:39 -08:00
parent 572d198b68
commit 2750d2084c
8 changed files with 109 additions and 5 deletions

View File

@ -75,6 +75,7 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
case c10::prim::TupleUnpack: case c10::prim::TupleUnpack:
case c10::prim::ListUnpack: case c10::prim::ListUnpack:
case c10::prim::dtype: case c10::prim::dtype:
case c10::prim::device:
case c10::prim::unchecked_cast: case c10::prim::unchecked_cast:
case c10::prim::Uninitialized: case c10::prim::Uninitialized:
case c10::prim::RaiseException: case c10::prim::RaiseException:
@ -93,8 +94,7 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
return; return;
} }
case c10::prim::GetAttr: case c10::prim::GetAttr:
case c10::prim::SetAttr: case c10::prim::SetAttr: {
case c10::prim::CallMethod: {
createAndMapNodeWithAttribute( createAndMapNodeWithAttribute(
node, "torch.prim." + std::string(kind.toUnqualString()), "name", node, "torch.prim." + std::string(kind.toUnqualString()), "name",
importAttribute(loc, node, c10::attr::name)); importAttribute(loc, node, c10::attr::name));
@ -187,6 +187,25 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
return; return;
} }
if (kind == c10::prim::CallMethod) {
auto classType = node->input(0)->type()->cast<c10::ClassType>();
auto methodName = node->s(c10::attr::name);
torch::jit::Function *function = classType->findMethod(methodName);
torch::jit::Block *calleeEntryBlock = function->graph()->block();
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
return typeMapper.mapFromTorchType(loc, v->type());
});
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.prim.CallMethod", loc,
getMlirTypesFromValues(loc, node->outputs()),
derefineValues(lookupMappedValues(node->inputs()), expectedTypes, loc,
appendToBlock),
toMlirNamedAttribute("name",
importAttribute(loc, node, c10::attr::name)));
mapResults(node, operation);
return;
}
if (kind == c10::prim::CallFunction) { if (kind == c10::prim::CallFunction) {
auto functionType = node->input(0)->type()->cast<c10::FunctionType>(); auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
torch::jit::Block *calleeEntryBlock = torch::jit::Block *calleeEntryBlock =

View File

@ -139,6 +139,9 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
case TypeKind::StringType: { case TypeKind::StringType: {
return npcompBytesTypeGet(context); return npcompBytesTypeGet(context);
} }
case TypeKind::DeviceObjType: {
return npcompDeviceTypeGet(context);
}
default: { default: {
std::stringstream message; std::stringstream message;
message << "unable to map Torch type '" << *torchType << "' to MLIR type"; message << "unable to map Torch type '" << *torchType << "' to MLIR type";

View File

@ -0,0 +1,35 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import typing
import torch
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
# CHECK-LABEL: func private @__torch__.TestModule.forward(
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">) -> !torch.optional<i64> {
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[NONE]] : !basicpy.NoneType -> !torch.optional<i64>
# CHECK: %[[RET:.*]] = torch.prim.CallMethod %[[SELF]]["callee"] (%[[DEREFINED]]) : !torch.nn.Module<"__torch__.TestModule">, (!torch.optional<i64>) -> !torch.optional<i64>
# CHECK: return %[[RET]] : !torch.optional<i64>
def forward(self):
return self.callee(None)
def callee(self, o: typing.Optional[int]):
return o
test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c)
mb.module.operation.print()

View File

@ -45,15 +45,15 @@ def prim_RaiseException():
raise Exception("Error") raise Exception("Error")
# CHECK-LABEL: func @__torch__.prim_unchecked_cast( # CHECK-LABEL: func @__torch__.prim_unchecked_cast(
# CHECK-SAME: %[[VAL_0:.*]]: !torch.optional<i64>) -> i64 { # CHECK-SAME: %[[ARG:.*]]: !torch.optional<i64>) -> i64 {
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType # CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
# CHECK: %[[C3:.*]] = constant 3 : i64 # CHECK: %[[C3:.*]] = constant 3 : i64
# CHECK: %[[IS_NONE:.*]] = torch.kernel_call "aten::__is__" %[[VAL_0]], %[[NONE]] : (!torch.optional<i64>, !basicpy.NoneType) -> !basicpy.BoolType # CHECK: %[[IS_NONE:.*]] = torch.kernel_call "aten::__is__" %[[ARG]], %[[NONE]] : (!torch.optional<i64>, !basicpy.NoneType) -> !basicpy.BoolType
# CHECK: %[[COND:.*]] = basicpy.bool_cast %[[IS_NONE]] : !basicpy.BoolType -> i1 # CHECK: %[[COND:.*]] = basicpy.bool_cast %[[IS_NONE]] : !basicpy.BoolType -> i1
# CHECK: %[[RESULT:.*]] = scf.if %[[COND]] -> (i64) { # CHECK: %[[RESULT:.*]] = scf.if %[[COND]] -> (i64) {
# CHECK: scf.yield %[[C3]] : i64 # CHECK: scf.yield %[[C3]] : i64
# CHECK: } else { # CHECK: } else {
# CHECK: %[[CASTED:.*]] = torch.prim.unchecked_cast %[[VAL_0]] : !torch.optional<i64> -> i64 # CHECK: %[[CASTED:.*]] = torch.prim.unchecked_cast %[[ARG]] : !torch.optional<i64> -> i64
# CHECK: scf.yield %[[CASTED]] : i64 # CHECK: scf.yield %[[CASTED]] : i64
# CHECK: } # CHECK: }
# CHECK: return %[[RESULT:.*]] : i64 # CHECK: return %[[RESULT:.*]] : i64
@ -102,5 +102,14 @@ def prim_ListUnpack(l: typing.List[int]):
def prim_dtype(x): def prim_dtype(x):
return x.dtype return x.dtype
# CHECK-LABEL: func @__torch__.prim_device(
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !torch.Device {
# CHECK: %[[RET:.*]] = torch.prim.device %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype> -> !torch.Device
# CHECK: return %[[RET]] : !torch.Device
@mb.import_function
@torch.jit.script
def prim_device(x):
return x.device
mb.module.operation.print() mb.module.operation.print()
print() print()

View File

@ -137,6 +137,16 @@ int npcompTypeIsAOptional(MlirType t);
/** Gets the !torch.optional<T> type with subtype T. */ /** Gets the !torch.optional<T> type with subtype T. */
MlirType npcompOptionalTypeGet(MlirType containedType); MlirType npcompOptionalTypeGet(MlirType containedType);
/*============================================================================*/
/* torch.Device type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.Device type */
int npcompTypeIsADevice(MlirType t);
/** Gets the !torch.Device type. */
MlirType npcompDeviceTypeGet(MlirContext context);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -561,4 +561,13 @@ def Torch_PrimdtypeOp : Torch_Op<"prim.dtype", []> {
}]; }];
} }
def Torch_PrimdeviceOp : Torch_Op<"prim.device", []> {
let summary = "TorchScript prim::device op";
let arguments = (ins AnyTorchTensorType:$tensor);
let results = (outs Torch_DeviceType:$result);
let assemblyFormat = [{
$tensor attr-dict `:` type($tensor) `->` type($result)
}];
}
#endif // TORCH_OPS #endif // TORCH_OPS

View File

@ -79,6 +79,10 @@ def Torch_OptionalType : Torch_Type<"Optional", "optional"> {
]; ];
} }
def Torch_DeviceType : Torch_Type<"Device", "Device"> {
let summary = "Torch device";
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Type predicates // Type predicates
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -178,6 +182,7 @@ def AnyTorchType : AnyTypeOf<[
Basicpy_BytesType, Basicpy_BytesType,
Torch_NnModuleType, Torch_NnModuleType,
Torch_OptionalType, Torch_OptionalType,
Torch_DeviceType,
], "Any type that is legal to pass to a Torch kernel">; ], "Any type that is legal to pass to a Torch kernel">;
#endif // TORCH_TYPES #endif // TORCH_TYPES

View File

@ -174,3 +174,17 @@ int npcompTypeIsAOptional(MlirType t) {
MlirType npcompOptionalTypeGet(MlirType containedType) { MlirType npcompOptionalTypeGet(MlirType containedType) {
return wrap(Torch::OptionalType::get(unwrap(containedType))); return wrap(Torch::OptionalType::get(unwrap(containedType)));
} }
/*============================================================================*/
/* torch.Device type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.Device type */
int npcompTypeIsADevice(MlirType t) {
return unwrap(t).isa<Torch::DeviceType>();
}
/** Gets the !torch.Device type. */
MlirType npcompDeviceTypeGet(MlirContext context) {
return wrap(Torch::DeviceType::get(unwrap(context)));
}