mirror of https://github.com/llvm/torch-mlir
Add prim::device and handle derefining for prim::CallMethod
parent
572d198b68
commit
2750d2084c
|
@ -75,6 +75,7 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
||||||
case c10::prim::TupleUnpack:
|
case c10::prim::TupleUnpack:
|
||||||
case c10::prim::ListUnpack:
|
case c10::prim::ListUnpack:
|
||||||
case c10::prim::dtype:
|
case c10::prim::dtype:
|
||||||
|
case c10::prim::device:
|
||||||
case c10::prim::unchecked_cast:
|
case c10::prim::unchecked_cast:
|
||||||
case c10::prim::Uninitialized:
|
case c10::prim::Uninitialized:
|
||||||
case c10::prim::RaiseException:
|
case c10::prim::RaiseException:
|
||||||
|
@ -93,8 +94,7 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
case c10::prim::GetAttr:
|
case c10::prim::GetAttr:
|
||||||
case c10::prim::SetAttr:
|
case c10::prim::SetAttr: {
|
||||||
case c10::prim::CallMethod: {
|
|
||||||
createAndMapNodeWithAttribute(
|
createAndMapNodeWithAttribute(
|
||||||
node, "torch.prim." + std::string(kind.toUnqualString()), "name",
|
node, "torch.prim." + std::string(kind.toUnqualString()), "name",
|
||||||
importAttribute(loc, node, c10::attr::name));
|
importAttribute(loc, node, c10::attr::name));
|
||||||
|
@ -187,6 +187,25 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (kind == c10::prim::CallMethod) {
|
||||||
|
auto classType = node->input(0)->type()->cast<c10::ClassType>();
|
||||||
|
auto methodName = node->s(c10::attr::name);
|
||||||
|
torch::jit::Function *function = classType->findMethod(methodName);
|
||||||
|
torch::jit::Block *calleeEntryBlock = function->graph()->block();
|
||||||
|
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
|
||||||
|
return typeMapper.mapFromTorchType(loc, v->type());
|
||||||
|
});
|
||||||
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
|
appendToBlock, "torch.prim.CallMethod", loc,
|
||||||
|
getMlirTypesFromValues(loc, node->outputs()),
|
||||||
|
derefineValues(lookupMappedValues(node->inputs()), expectedTypes, loc,
|
||||||
|
appendToBlock),
|
||||||
|
toMlirNamedAttribute("name",
|
||||||
|
importAttribute(loc, node, c10::attr::name)));
|
||||||
|
mapResults(node, operation);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (kind == c10::prim::CallFunction) {
|
if (kind == c10::prim::CallFunction) {
|
||||||
auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
|
auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
|
||||||
torch::jit::Block *calleeEntryBlock =
|
torch::jit::Block *calleeEntryBlock =
|
||||||
|
|
|
@ -139,6 +139,9 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
||||||
case TypeKind::StringType: {
|
case TypeKind::StringType: {
|
||||||
return npcompBytesTypeGet(context);
|
return npcompBytesTypeGet(context);
|
||||||
}
|
}
|
||||||
|
case TypeKind::DeviceObjType: {
|
||||||
|
return npcompDeviceTypeGet(context);
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
std::stringstream message;
|
std::stringstream message;
|
||||||
message << "unable to map Torch type '" << *torchType << "' to MLIR type";
|
message << "unable to map Torch type '" << *torchType << "' to MLIR type";
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||||
|
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TestModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# CHECK-LABEL: func private @__torch__.TestModule.forward(
|
||||||
|
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">) -> !torch.optional<i64> {
|
||||||
|
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||||
|
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[NONE]] : !basicpy.NoneType -> !torch.optional<i64>
|
||||||
|
# CHECK: %[[RET:.*]] = torch.prim.CallMethod %[[SELF]]["callee"] (%[[DEREFINED]]) : !torch.nn.Module<"__torch__.TestModule">, (!torch.optional<i64>) -> !torch.optional<i64>
|
||||||
|
# CHECK: return %[[RET]] : !torch.optional<i64>
|
||||||
|
def forward(self):
|
||||||
|
return self.callee(None)
|
||||||
|
def callee(self, o: typing.Optional[int]):
|
||||||
|
return o
|
||||||
|
|
||||||
|
test_module = TestModule()
|
||||||
|
recursivescriptmodule = torch.jit.script(test_module)
|
||||||
|
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||||
|
mb.import_module(recursivescriptmodule._c)
|
||||||
|
mb.module.operation.print()
|
|
@ -45,15 +45,15 @@ def prim_RaiseException():
|
||||||
raise Exception("Error")
|
raise Exception("Error")
|
||||||
|
|
||||||
# CHECK-LABEL: func @__torch__.prim_unchecked_cast(
|
# CHECK-LABEL: func @__torch__.prim_unchecked_cast(
|
||||||
# CHECK-SAME: %[[VAL_0:.*]]: !torch.optional<i64>) -> i64 {
|
# CHECK-SAME: %[[ARG:.*]]: !torch.optional<i64>) -> i64 {
|
||||||
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
|
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||||
# CHECK: %[[C3:.*]] = constant 3 : i64
|
# CHECK: %[[C3:.*]] = constant 3 : i64
|
||||||
# CHECK: %[[IS_NONE:.*]] = torch.kernel_call "aten::__is__" %[[VAL_0]], %[[NONE]] : (!torch.optional<i64>, !basicpy.NoneType) -> !basicpy.BoolType
|
# CHECK: %[[IS_NONE:.*]] = torch.kernel_call "aten::__is__" %[[ARG]], %[[NONE]] : (!torch.optional<i64>, !basicpy.NoneType) -> !basicpy.BoolType
|
||||||
# CHECK: %[[COND:.*]] = basicpy.bool_cast %[[IS_NONE]] : !basicpy.BoolType -> i1
|
# CHECK: %[[COND:.*]] = basicpy.bool_cast %[[IS_NONE]] : !basicpy.BoolType -> i1
|
||||||
# CHECK: %[[RESULT:.*]] = scf.if %[[COND]] -> (i64) {
|
# CHECK: %[[RESULT:.*]] = scf.if %[[COND]] -> (i64) {
|
||||||
# CHECK: scf.yield %[[C3]] : i64
|
# CHECK: scf.yield %[[C3]] : i64
|
||||||
# CHECK: } else {
|
# CHECK: } else {
|
||||||
# CHECK: %[[CASTED:.*]] = torch.prim.unchecked_cast %[[VAL_0]] : !torch.optional<i64> -> i64
|
# CHECK: %[[CASTED:.*]] = torch.prim.unchecked_cast %[[ARG]] : !torch.optional<i64> -> i64
|
||||||
# CHECK: scf.yield %[[CASTED]] : i64
|
# CHECK: scf.yield %[[CASTED]] : i64
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
# CHECK: return %[[RESULT:.*]] : i64
|
# CHECK: return %[[RESULT:.*]] : i64
|
||||||
|
@ -102,5 +102,14 @@ def prim_ListUnpack(l: typing.List[int]):
|
||||||
def prim_dtype(x):
|
def prim_dtype(x):
|
||||||
return x.dtype
|
return x.dtype
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @__torch__.prim_device(
|
||||||
|
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !torch.Device {
|
||||||
|
# CHECK: %[[RET:.*]] = torch.prim.device %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype> -> !torch.Device
|
||||||
|
# CHECK: return %[[RET]] : !torch.Device
|
||||||
|
@mb.import_function
|
||||||
|
@torch.jit.script
|
||||||
|
def prim_device(x):
|
||||||
|
return x.device
|
||||||
|
|
||||||
mb.module.operation.print()
|
mb.module.operation.print()
|
||||||
print()
|
print()
|
||||||
|
|
|
@ -137,6 +137,16 @@ int npcompTypeIsAOptional(MlirType t);
|
||||||
/** Gets the !torch.optional<T> type with subtype T. */
|
/** Gets the !torch.optional<T> type with subtype T. */
|
||||||
MlirType npcompOptionalTypeGet(MlirType containedType);
|
MlirType npcompOptionalTypeGet(MlirType containedType);
|
||||||
|
|
||||||
|
/*============================================================================*/
|
||||||
|
/* torch.Device type. */
|
||||||
|
/*============================================================================*/
|
||||||
|
|
||||||
|
/** Checks whether the given type is a !torch.Device type */
|
||||||
|
int npcompTypeIsADevice(MlirType t);
|
||||||
|
|
||||||
|
/** Gets the !torch.Device type. */
|
||||||
|
MlirType npcompDeviceTypeGet(MlirContext context);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -561,4 +561,13 @@ def Torch_PrimdtypeOp : Torch_Op<"prim.dtype", []> {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_PrimdeviceOp : Torch_Op<"prim.device", []> {
|
||||||
|
let summary = "TorchScript prim::device op";
|
||||||
|
let arguments = (ins AnyTorchTensorType:$tensor);
|
||||||
|
let results = (outs Torch_DeviceType:$result);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$tensor attr-dict `:` type($tensor) `->` type($result)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // TORCH_OPS
|
#endif // TORCH_OPS
|
||||||
|
|
|
@ -79,6 +79,10 @@ def Torch_OptionalType : Torch_Type<"Optional", "optional"> {
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_DeviceType : Torch_Type<"Device", "Device"> {
|
||||||
|
let summary = "Torch device";
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Type predicates
|
// Type predicates
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -178,6 +182,7 @@ def AnyTorchType : AnyTypeOf<[
|
||||||
Basicpy_BytesType,
|
Basicpy_BytesType,
|
||||||
Torch_NnModuleType,
|
Torch_NnModuleType,
|
||||||
Torch_OptionalType,
|
Torch_OptionalType,
|
||||||
|
Torch_DeviceType,
|
||||||
], "Any type that is legal to pass to a Torch kernel">;
|
], "Any type that is legal to pass to a Torch kernel">;
|
||||||
|
|
||||||
#endif // TORCH_TYPES
|
#endif // TORCH_TYPES
|
||||||
|
|
|
@ -174,3 +174,17 @@ int npcompTypeIsAOptional(MlirType t) {
|
||||||
MlirType npcompOptionalTypeGet(MlirType containedType) {
|
MlirType npcompOptionalTypeGet(MlirType containedType) {
|
||||||
return wrap(Torch::OptionalType::get(unwrap(containedType)));
|
return wrap(Torch::OptionalType::get(unwrap(containedType)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*============================================================================*/
|
||||||
|
/* torch.Device type. */
|
||||||
|
/*============================================================================*/
|
||||||
|
|
||||||
|
/** Checks whether the given type is a !torch.Device type */
|
||||||
|
int npcompTypeIsADevice(MlirType t) {
|
||||||
|
return unwrap(t).isa<Torch::DeviceType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Gets the !torch.Device type. */
|
||||||
|
MlirType npcompDeviceTypeGet(MlirContext context) {
|
||||||
|
return wrap(Torch::DeviceType::get(unwrap(context)));
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue