mirror of https://github.com/llvm/torch-mlir
Add prim::device and handle derefining for prim::CallMethod
parent
572d198b68
commit
2750d2084c
|
@ -75,6 +75,7 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
|||
case c10::prim::TupleUnpack:
|
||||
case c10::prim::ListUnpack:
|
||||
case c10::prim::dtype:
|
||||
case c10::prim::device:
|
||||
case c10::prim::unchecked_cast:
|
||||
case c10::prim::Uninitialized:
|
||||
case c10::prim::RaiseException:
|
||||
|
@ -93,8 +94,7 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
|||
return;
|
||||
}
|
||||
case c10::prim::GetAttr:
|
||||
case c10::prim::SetAttr:
|
||||
case c10::prim::CallMethod: {
|
||||
case c10::prim::SetAttr: {
|
||||
createAndMapNodeWithAttribute(
|
||||
node, "torch.prim." + std::string(kind.toUnqualString()), "name",
|
||||
importAttribute(loc, node, c10::attr::name));
|
||||
|
@ -187,6 +187,25 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
|||
return;
|
||||
}
|
||||
|
||||
if (kind == c10::prim::CallMethod) {
|
||||
auto classType = node->input(0)->type()->cast<c10::ClassType>();
|
||||
auto methodName = node->s(c10::attr::name);
|
||||
torch::jit::Function *function = classType->findMethod(methodName);
|
||||
torch::jit::Block *calleeEntryBlock = function->graph()->block();
|
||||
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
|
||||
return typeMapper.mapFromTorchType(loc, v->type());
|
||||
});
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.prim.CallMethod", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
derefineValues(lookupMappedValues(node->inputs()), expectedTypes, loc,
|
||||
appendToBlock),
|
||||
toMlirNamedAttribute("name",
|
||||
importAttribute(loc, node, c10::attr::name)));
|
||||
mapResults(node, operation);
|
||||
return;
|
||||
}
|
||||
|
||||
if (kind == c10::prim::CallFunction) {
|
||||
auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
|
||||
torch::jit::Block *calleeEntryBlock =
|
||||
|
|
|
@ -139,6 +139,9 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
case TypeKind::StringType: {
|
||||
return npcompBytesTypeGet(context);
|
||||
}
|
||||
case TypeKind::DeviceObjType: {
|
||||
return npcompDeviceTypeGet(context);
|
||||
}
|
||||
default: {
|
||||
std::stringstream message;
|
||||
message << "unable to map Torch type '" << *torchType << "' to MLIR type";
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import typing
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# CHECK-LABEL: func private @__torch__.TestModule.forward(
|
||||
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">) -> !torch.optional<i64> {
|
||||
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[NONE]] : !basicpy.NoneType -> !torch.optional<i64>
|
||||
# CHECK: %[[RET:.*]] = torch.prim.CallMethod %[[SELF]]["callee"] (%[[DEREFINED]]) : !torch.nn.Module<"__torch__.TestModule">, (!torch.optional<i64>) -> !torch.optional<i64>
|
||||
# CHECK: return %[[RET]] : !torch.optional<i64>
|
||||
def forward(self):
|
||||
return self.callee(None)
|
||||
def callee(self, o: typing.Optional[int]):
|
||||
return o
|
||||
|
||||
test_module = TestModule()
|
||||
recursivescriptmodule = torch.jit.script(test_module)
|
||||
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||
mb.import_module(recursivescriptmodule._c)
|
||||
mb.module.operation.print()
|
|
@ -45,15 +45,15 @@ def prim_RaiseException():
|
|||
raise Exception("Error")
|
||||
|
||||
# CHECK-LABEL: func @__torch__.prim_unchecked_cast(
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: !torch.optional<i64>) -> i64 {
|
||||
# CHECK-SAME: %[[ARG:.*]]: !torch.optional<i64>) -> i64 {
|
||||
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||
# CHECK: %[[C3:.*]] = constant 3 : i64
|
||||
# CHECK: %[[IS_NONE:.*]] = torch.kernel_call "aten::__is__" %[[VAL_0]], %[[NONE]] : (!torch.optional<i64>, !basicpy.NoneType) -> !basicpy.BoolType
|
||||
# CHECK: %[[IS_NONE:.*]] = torch.kernel_call "aten::__is__" %[[ARG]], %[[NONE]] : (!torch.optional<i64>, !basicpy.NoneType) -> !basicpy.BoolType
|
||||
# CHECK: %[[COND:.*]] = basicpy.bool_cast %[[IS_NONE]] : !basicpy.BoolType -> i1
|
||||
# CHECK: %[[RESULT:.*]] = scf.if %[[COND]] -> (i64) {
|
||||
# CHECK: scf.yield %[[C3]] : i64
|
||||
# CHECK: } else {
|
||||
# CHECK: %[[CASTED:.*]] = torch.prim.unchecked_cast %[[VAL_0]] : !torch.optional<i64> -> i64
|
||||
# CHECK: %[[CASTED:.*]] = torch.prim.unchecked_cast %[[ARG]] : !torch.optional<i64> -> i64
|
||||
# CHECK: scf.yield %[[CASTED]] : i64
|
||||
# CHECK: }
|
||||
# CHECK: return %[[RESULT:.*]] : i64
|
||||
|
@ -102,5 +102,14 @@ def prim_ListUnpack(l: typing.List[int]):
|
|||
def prim_dtype(x):
|
||||
return x.dtype
|
||||
|
||||
# CHECK-LABEL: func @__torch__.prim_device(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !torch.Device {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.device %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype> -> !torch.Device
|
||||
# CHECK: return %[[RET]] : !torch.Device
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def prim_device(x):
|
||||
return x.device
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
||||
|
|
|
@ -137,6 +137,16 @@ int npcompTypeIsAOptional(MlirType t);
|
|||
/** Gets the !torch.optional<T> type with subtype T. */
|
||||
MlirType npcompOptionalTypeGet(MlirType containedType);
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.Device type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a !torch.Device type */
|
||||
int npcompTypeIsADevice(MlirType t);
|
||||
|
||||
/** Gets the !torch.Device type. */
|
||||
MlirType npcompDeviceTypeGet(MlirContext context);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -561,4 +561,13 @@ def Torch_PrimdtypeOp : Torch_Op<"prim.dtype", []> {
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_PrimdeviceOp : Torch_Op<"prim.device", []> {
|
||||
let summary = "TorchScript prim::device op";
|
||||
let arguments = (ins AnyTorchTensorType:$tensor);
|
||||
let results = (outs Torch_DeviceType:$result);
|
||||
let assemblyFormat = [{
|
||||
$tensor attr-dict `:` type($tensor) `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TORCH_OPS
|
||||
|
|
|
@ -79,6 +79,10 @@ def Torch_OptionalType : Torch_Type<"Optional", "optional"> {
|
|||
];
|
||||
}
|
||||
|
||||
def Torch_DeviceType : Torch_Type<"Device", "Device"> {
|
||||
let summary = "Torch device";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type predicates
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -178,6 +182,7 @@ def AnyTorchType : AnyTypeOf<[
|
|||
Basicpy_BytesType,
|
||||
Torch_NnModuleType,
|
||||
Torch_OptionalType,
|
||||
Torch_DeviceType,
|
||||
], "Any type that is legal to pass to a Torch kernel">;
|
||||
|
||||
#endif // TORCH_TYPES
|
||||
|
|
|
@ -174,3 +174,17 @@ int npcompTypeIsAOptional(MlirType t) {
|
|||
MlirType npcompOptionalTypeGet(MlirType containedType) {
|
||||
return wrap(Torch::OptionalType::get(unwrap(containedType)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.Device type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a !torch.Device type */
|
||||
int npcompTypeIsADevice(MlirType t) {
|
||||
return unwrap(t).isa<Torch::DeviceType>();
|
||||
}
|
||||
|
||||
/** Gets the !torch.Device type. */
|
||||
MlirType npcompDeviceTypeGet(MlirContext context) {
|
||||
return wrap(Torch::DeviceType::get(unwrap(context)));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue