mirror of https://github.com/llvm/torch-mlir
Implement GlobalizeObjectGraph transformation.
This required restructuring of how we model TorchScript on import. The main difference is that now we split out a `torch.class_type` that holds methods and declarations of the types of each slot. This is more consistent with TorchScript (our previous representation was "denormalized"). Recommended reading order: 1. check out the description of `torch.class_type` in `TorchOps.td` and look at `test/Dialect/Torch/ops.mlir` and `frontends/pytorch/test/module_import/` to familiarize with the new representation. - Just look at the new IR. The diff between the old names and new names is confusing. 2. check out `test/Dialect/Torch/globalize-object-graph*.mlir` and read along with the pass description in `include/npcomp/Dialect/Torch/Transforms/Passes.td` 3. Read the code in `GlobalizeObjectGraph.cpp` and miscellaneous changes in `ivalue_importer.cpp`, `TorchOps.cpp`, etc.pull/162/head
parent
99d1db18d2
commit
158c5c484d
|
@ -101,7 +101,8 @@ public:
|
|||
private:
|
||||
MlirValue rawImportIValue(c10::IValue value);
|
||||
MlirValue importModule(torch::jit::Module jitModule);
|
||||
void importMethod(torch::jit::Function *function, MlirBlock nnModuleBody);
|
||||
void importMethod(torch::jit::Function *function, MlirBlock classTypeBody);
|
||||
void importClassType(c10::ClassType *classType);
|
||||
|
||||
MlirBlock importBlock;
|
||||
MlirContext context;
|
||||
|
@ -111,6 +112,12 @@ private:
|
|||
std::unordered_map<c10::IValue, MlirValue, IValueHasher, IValueEq> valueMap;
|
||||
// Used to detect potentially aliasing tensors.
|
||||
std::unordered_set<c10::StorageImpl *> seenStorageImpls;
|
||||
// The set of ClassType's that have already been imported.
|
||||
//
|
||||
// ClassType's are referenced via their `classType->name()->qualifiedName()`
|
||||
// string (as an MLIR symbol name) so we don't need to keep a map associating
|
||||
// them with the MlirOperation that they import into.
|
||||
std::unordered_set<c10::ClassType *> classTypes;
|
||||
// The stack of attribute names we have traversed to reach the current IValue.
|
||||
// Used for diagnostics.
|
||||
std::vector<std::string> attributeNameStack;
|
||||
|
@ -128,16 +135,25 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
|||
// TODO: Can we do better?
|
||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||
|
||||
MlirOperation nnModule =
|
||||
createMlirOperation("torch.nn_module", loc,
|
||||
npcompNnModuleTypeGet(context), mlirRegionCreate());
|
||||
c10::optional<c10::QualifiedName> maybeName = currentModule.type()->name();
|
||||
if (!maybeName) {
|
||||
throw std::invalid_argument("cannot import unnamed module");
|
||||
}
|
||||
std::string moduleTypeName = maybeName->qualifiedName();
|
||||
|
||||
// Ensure the class type has been imported.
|
||||
importClassType(currentModule.type().get());
|
||||
|
||||
MlirOperation nnModule = createMlirOperation(
|
||||
"torch.nn_module", loc,
|
||||
npcompNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
|
||||
mlirRegionCreate());
|
||||
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
||||
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
|
||||
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
|
||||
|
||||
if (!rootModuleName.has_value()) {
|
||||
c10::optional<c10::QualifiedName> maybeName = currentModule.type()->name();
|
||||
rootModuleName = maybeName ? maybeName->qualifiedName() : "unnamed module";
|
||||
rootModuleName = moduleTypeName;
|
||||
}
|
||||
|
||||
const std::vector<c10::IValue> &slots = currentModule._ivalue()->slots();
|
||||
|
@ -151,7 +167,7 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
|||
MlirValue slotValue = importIValue(slots[i]);
|
||||
// TODO: Is it necessary to track whether an attribute is a "parameter"?
|
||||
createMlirOperationAtEnd(
|
||||
nnModuleBody, "torch.attr", loc, slotValue,
|
||||
nnModuleBody, "torch.slot", loc, slotValue,
|
||||
toMlirNamedAttribute(
|
||||
"name", mlirStringAttrGet(
|
||||
context, toMlirStringRef(classAttribute.getName()))));
|
||||
|
@ -162,10 +178,6 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
|||
rootModuleName = c10::nullopt;
|
||||
}
|
||||
|
||||
for (torch::jit::Function *function : currentModule.type()->methods()) {
|
||||
importMethod(function, nnModuleBody);
|
||||
}
|
||||
|
||||
createMlirOperationAtEnd(nnModuleBody, "torch.nn_module_terminator", loc);
|
||||
mlirBlockInsertOwnedOperationBefore(
|
||||
importBlock, mlirBlockGetTerminator(importBlock), nnModule);
|
||||
|
@ -262,7 +274,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
}
|
||||
|
||||
void IValueImporter::importMethod(torch::jit::Function *function,
|
||||
MlirBlock nnModuleBody) {
|
||||
MlirBlock classTypeBody) {
|
||||
// We make an effort for the func op's symbol name to be useful for debugging,
|
||||
// but still clearly non-load-bearing.
|
||||
std::string symName =
|
||||
|
@ -275,13 +287,50 @@ void IValueImporter::importMethod(torch::jit::Function *function,
|
|||
mlirBlockInsertOwnedOperationBefore(
|
||||
importBlock, mlirBlockGetTerminator(importBlock), func);
|
||||
createMlirOperationAtEnd(
|
||||
nnModuleBody, "torch.method", mlirLocationUnknownGet(context),
|
||||
classTypeBody, "torch.method", mlirLocationUnknownGet(context),
|
||||
toMlirNamedAttribute(
|
||||
"name",
|
||||
mlirStringAttrGet(context, toMlirStringRef(function->name()))),
|
||||
toMlirNamedAttribute("function", mlirFlatSymbolRefAttrGet(
|
||||
context, toMlirStringRef(symName))));
|
||||
}
|
||||
|
||||
void IValueImporter::importClassType(c10::ClassType *classType) {
|
||||
if (!classTypes.insert(classType).second) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: Can we do better?
|
||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||
|
||||
MlirOperation op = createMlirOperationAtEnd(
|
||||
importBlock, "torch.class_type", loc, mlirRegionCreate(),
|
||||
toMlirNamedAttribute(
|
||||
"sym_name",
|
||||
mlirStringAttrGet(
|
||||
context, toMlirStringRef(classType->name()->qualifiedName()))));
|
||||
MlirRegion region = mlirOperationGetRegion(op, 0);
|
||||
mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr));
|
||||
MlirBlock classTypeBody = mlirRegionGetFirstBlock(region);
|
||||
|
||||
for (const c10::ClassAttribute &classAttribute : classType->getAttributes()) {
|
||||
createMlirOperationAtEnd(
|
||||
classTypeBody, "torch.attr", loc,
|
||||
toMlirNamedAttribute(
|
||||
"name", mlirStringAttrGet(
|
||||
context, toMlirStringRef(classAttribute.getName()))),
|
||||
toMlirNamedAttribute("type",
|
||||
mlirTypeAttrGet(typeMapper.mapFromTorchType(
|
||||
loc, classAttribute.getType()))));
|
||||
}
|
||||
|
||||
for (torch::jit::Function *function : classType->methods()) {
|
||||
importMethod(function, classTypeBody);
|
||||
}
|
||||
|
||||
createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc);
|
||||
}
|
||||
|
||||
void torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
|
||||
MlirContext context) {
|
||||
// When debugging module importing, it can be useful to dump as so:
|
||||
|
|
|
@ -106,11 +106,19 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType);
|
||||
}
|
||||
case TypeKind::ClassType: {
|
||||
return npcompNnModuleTypeGet(context);
|
||||
auto maybeName = torchType->cast<c10::ClassType>()->name();
|
||||
return npcompNnModuleTypeGet(
|
||||
context, toMlirStringRef(maybeName ? maybeName->qualifiedName()
|
||||
: "unnamed class"));
|
||||
}
|
||||
case TypeKind::FloatType: {
|
||||
return mlirF64TypeGet(context);
|
||||
}
|
||||
case TypeKind::OptionalType: {
|
||||
return npcompOptionalTypeGet(
|
||||
mapFromTorchType(
|
||||
loc, torchType->cast<c10::OptionalType>()->getElementType()));
|
||||
}
|
||||
case TypeKind::IntType: {
|
||||
return mlirIntegerTypeGet(context, 64);
|
||||
}
|
||||
|
@ -120,6 +128,10 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
case TypeKind::BoolType: {
|
||||
return npcompBoolTypeGet(context);
|
||||
}
|
||||
case TypeKind::ListType: {
|
||||
// TODO: Don't lose the element type information.
|
||||
return npcompListTypeGet(context);
|
||||
}
|
||||
default: {
|
||||
std::stringstream message;
|
||||
message << "unable to map Torch type " << *torchType << " to MLIR type";
|
||||
|
|
|
@ -15,13 +15,16 @@ class TestModule(torch.nn.Module):
|
|||
def __init__(self):
|
||||
super().__init__()
|
||||
self.l = [1, 2]
|
||||
|
||||
# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
|
||||
# TODO: Don't lose element type.
|
||||
# CHECK: torch.attr "l" : !basicpy.ListType
|
||||
# CHECK: }
|
||||
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
|
||||
# CHECK: %[[N2:.*]] = basicpy.numeric_constant 2 : i64
|
||||
# CHECK: %[[LIST:.*]] = basicpy.build_list %[[N1]], %[[N2]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: torch.nn_module {
|
||||
# CHECK: torch.attr "l", %[[LIST]] : !basicpy.ListType
|
||||
# CHECK: }
|
||||
# CHECK: torch.slot "l", %[[LIST]] : !basicpy.ListType
|
||||
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
|
||||
|
||||
|
||||
test_module = TestModule()
|
||||
|
|
|
@ -32,8 +32,8 @@ class TestModule(torch.nn.Module):
|
|||
# the case that the name is `__main__` Torch replaces it with `__torch__` to
|
||||
# avoid collisions.
|
||||
|
||||
# CHECK: func private @__npcomp_priv_fn.__torch__.Submodule.forward
|
||||
# CHECK: func private @__npcomp_priv_fn.__torch__.TestModule.forward
|
||||
# CHECK-DAG: func private @__npcomp_priv_fn.__torch__.TestModule.forward
|
||||
# CHECK=DAG: func private @__npcomp_priv_fn.__torch__.Submodule.forward
|
||||
|
||||
|
||||
test_module = TestModule()
|
||||
|
|
|
@ -19,17 +19,22 @@ class TestModule(torch.nn.Module):
|
|||
|
||||
# The symbol name of the function is NOT load-bearing and cannot be relied upon.
|
||||
|
||||
# CHECK-LABEL: torch.class_type
|
||||
# CHECK-SAME: @[[CLASSTYPE:.*]] {
|
||||
# CHECK: torch.method "forward", @[[SYMNAME:.*]]
|
||||
# CHECK: }
|
||||
|
||||
|
||||
# CHECK-LABEL: func private
|
||||
# CHECK-SAME: @[[SYMNAME:.*]](
|
||||
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module,
|
||||
# CHECK-SAME: @[[SYMNAME]](
|
||||
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"[[CLASSTYPE]]">,
|
||||
# CHECK-SAME: %[[X:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
||||
# CHECK-SAME: %[[Y:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
# CHECK: %[[RET:.*]] = torch.kernel_call "aten::mul" %[[X]], %[[Y]]
|
||||
# CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
|
||||
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
||||
# CHECK: torch.method "forward", @[[SYMNAME]]
|
||||
# CHECK: }
|
||||
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
|
||||
|
||||
|
||||
test_module = TestModule()
|
||||
|
|
|
@ -24,8 +24,8 @@ class TestModule(torch.nn.Module):
|
|||
# CHECK: %[[L2:.*]] = basicpy.build_list
|
||||
# CHECK: %[[L1:.*]] = basicpy.build_list
|
||||
# CHECK: torch.nn_module {
|
||||
# CHECK: torch.attr "l2", %[[L2]]
|
||||
# CHECK: torch.attr "l1", %[[L1]]
|
||||
# CHECK: torch.slot "l2", %[[L2]]
|
||||
# CHECK: torch.slot "l1", %[[L1]]
|
||||
self.l2 = self.l1 = [1]
|
||||
|
||||
# This can be uncommented when the graph importer supports it.
|
||||
|
|
|
@ -16,8 +16,8 @@ class TestModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
# CHECK: %[[A:.*]] = numpy.create_array_from_tensor
|
||||
# CHECK: torch.nn_module {
|
||||
# CHECK: torch.attr "t1", %[[A]]
|
||||
# CHECK: torch.attr "t2", %[[A]]
|
||||
# CHECK: torch.slot "t1", %[[A]]
|
||||
# CHECK: torch.slot "t2", %[[A]]
|
||||
self.t1 = self.t2 = torch.tensor([10., 20.])
|
||||
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ class TestModule(torch.nn.Module):
|
|||
self.t2 = torch.ones(1)
|
||||
|
||||
# CHECK-LABEL: func{{.*}}TestModule.forward{{.*}}(
|
||||
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module) -> !basicpy.NoneType {
|
||||
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">) -> !basicpy.NoneType {
|
||||
def forward(self):
|
||||
# CHECK: %[[T2:.*]] = torch.prim.GetAttr %[[SELF]]["t2"]
|
||||
# CHECK: torch.prim.SetAttr %[[SELF]]["t1"] = %[[T2]]
|
||||
|
@ -26,7 +26,7 @@ class TestModule(torch.nn.Module):
|
|||
# CHECK: torch.prim.CallMethod %[[SELF]]["callee"] (%{{.*}}, %{{.*}})
|
||||
self.callee(self.t1, self.t2)
|
||||
# CHECK-LABEL: func{{.*}}TestModule.callee{{.*}}(
|
||||
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module,
|
||||
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">,
|
||||
# CHECK-SAME: %[[X:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
||||
# CHECK-SAME: %[[Y:.*]]: !numpy.ndarray<*:!numpy.any_dtype>
|
||||
def callee(self, x, y):
|
||||
|
|
|
@ -17,14 +17,20 @@ class TestModule(torch.nn.Module):
|
|||
self.i = 3
|
||||
self.f = 42.5
|
||||
|
||||
# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
|
||||
# CHECK: torch.attr "training" : !basicpy.BoolType
|
||||
# CHECK: torch.attr "i" : i64
|
||||
# CHECK: torch.attr "f" : f64
|
||||
# CHECK: }
|
||||
# CHECK: %[[TRUE:.*]] = basicpy.bool_constant true
|
||||
# CHECK: %[[N3:.*]] = basicpy.numeric_constant 3 : i64
|
||||
# CHECK: %[[N42:.*]] = basicpy.numeric_constant 4.250000e+01 : f64
|
||||
# CHECK: %[[MODULE:.*]] = torch.nn_module {
|
||||
# Note: for some reason, Torch always adds a "training" property to all modules.
|
||||
# CHECK: torch.attr "training", %[[TRUE]] : !basicpy.BoolType
|
||||
# CHECK: torch.attr "i", %[[N3]] : i64
|
||||
# CHECK: torch.attr "f", %[[N42]] : f64
|
||||
# CHECK: torch.slot "training", %[[TRUE]] : !basicpy.BoolType
|
||||
# CHECK: torch.slot "i", %[[N3]] : i64
|
||||
# CHECK: torch.slot "f", %[[N42]] : f64
|
||||
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE:.*]]">
|
||||
|
||||
|
||||
test_module = TestModule()
|
||||
|
|
|
@ -15,6 +15,8 @@ class Submodule(torch.nn.Module):
|
|||
def __init__(self, n):
|
||||
super().__init__()
|
||||
self.n = n
|
||||
def forward(self):
|
||||
return self.n
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -27,9 +29,9 @@ class TestModule(torch.nn.Module):
|
|||
# Modules with the same class can be selected between.
|
||||
# CHECK: %[[MOD:.*]] = scf.if
|
||||
s = self.s1 if b else self.s2
|
||||
# CHECK: %[[N:.*]] = torch.prim.GetAttr %5["n"]
|
||||
# CHECK: %[[N:.*]] = torch.prim.CallMethod %[[MOD]]["forward"] ()
|
||||
# CHECK: return %[[N]]
|
||||
return s.n
|
||||
return s.forward()
|
||||
|
||||
|
||||
test_module = TestModule()
|
||||
|
|
|
@ -26,20 +26,20 @@ class TestModule(torch.nn.Module):
|
|||
|
||||
# CHECK: %[[N0:.*]] = basicpy.numeric_constant 0 : i64
|
||||
# CHECK: %[[S0:.*]] = torch.nn_module {
|
||||
# CHECK: torch.attr "training", %[[T]] : !basicpy.BoolType
|
||||
# CHECK: torch.attr "n", %[[N0]] : i64
|
||||
# CHECK: torch.slot "training", %[[T]] : !basicpy.BoolType
|
||||
# CHECK: torch.slot "n", %[[N0]] : i64
|
||||
# CHECK: }
|
||||
|
||||
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
|
||||
# CHECK: %[[S1:.*]] = torch.nn_module {
|
||||
# CHECK: torch.attr "training", %[[T]] : !basicpy.BoolType
|
||||
# CHECK: torch.attr "n", %[[N1]] : i64
|
||||
# CHECK: torch.slot "training", %[[T]] : !basicpy.BoolType
|
||||
# CHECK: torch.slot "n", %[[N1]] : i64
|
||||
# CHECK: }
|
||||
|
||||
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
||||
# CHECK: torch.attr "training", %[[T]] : !basicpy.BoolType
|
||||
# CHECK: torch.attr "s0", %[[S0]] : !torch.nn.Module
|
||||
# CHECK: torch.attr "s1", %[[S1]] : !torch.nn.Module
|
||||
# CHECK: torch.slot "training", %[[T]] : !basicpy.BoolType
|
||||
# CHECK: torch.slot "s0", %[[S0]] : !torch.nn.Module
|
||||
# CHECK: torch.slot "s1", %[[S1]] : !torch.nn.Module
|
||||
# CHECK: }
|
||||
|
||||
|
||||
|
|
|
@ -23,8 +23,8 @@ class TestModule(torch.nn.Module):
|
|||
# CHECK: %[[CT:.*]] = constant dense<1.000000e+00> : tensor<1xf32>
|
||||
# CHECK: %[[T:.*]] = numpy.create_array_from_tensor %[[CT]] : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
||||
# CHECK: torch.attr "p", %[[P]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK: torch.attr "t", %[[T]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK: torch.slot "p", %[[P]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK: torch.slot "t", %[[T]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK: }
|
||||
|
||||
|
||||
|
|
|
@ -124,8 +124,18 @@ MlirType npcompTupleTypeGet(MlirContext context);
|
|||
/** Checks whether the given type is a torch.nn.Module type */
|
||||
int npcompTypeIsANnModule(MlirType t);
|
||||
|
||||
/** Gets the singleton torch.nn.Module type. */
|
||||
MlirType npcompNnModuleTypeGet(MlirContext context);
|
||||
/** Gets the !torch.nn.Module type of the specified class. */
|
||||
MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className);
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.optional type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a !torch.optional<T> type */
|
||||
int npcompTypeIsAOptional(MlirType t);
|
||||
|
||||
/** Gets the !torch.optional<T> type with subtype T. */
|
||||
MlirType npcompOptionalTypeGet(MlirType containedType);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
|
|
@ -49,10 +49,11 @@ def Torch_KernelCallOp : Torch_Op<"kernel_call", [
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TorchScript modeling ops.
|
||||
// TorchScript `torch.nn.Module` object instantiation ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
|
||||
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::NnModuleTerminatorOp">]> {
|
||||
let summary = "Constructs a torch.nn.Module";
|
||||
let description = [{
|
||||
|
@ -66,14 +67,18 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
|||
|
||||
```mlir
|
||||
%2 = torch.nn_module {
|
||||
torch.attr "b", %bool_true : !basicpy.BoolType
|
||||
torch.attr "i", %num3_i64 : i64
|
||||
torch.attr "f", %num : f64
|
||||
torch.attr "t", %0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.attr "submodule", %1 : !torch.nn.Module
|
||||
torch.method "method", @f
|
||||
}
|
||||
torch.slot "b", %bool_true : !basicpy.BoolType
|
||||
torch.slot "i", %num3_i64 : i64
|
||||
torch.slot "f", %num : f64
|
||||
torch.slot "t", %0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.slot "submodule", %1 : !torch.nn.Module
|
||||
} : !torch.nn.Module<"my_class_name">
|
||||
```
|
||||
|
||||
This op is tightly coupled to the `torch.class_type` op named in the
|
||||
`!torch.nn.Module<"my_class_name">` type. Each slot must match precisely
|
||||
with the corresponding `torch.attr` in the `torch.class_type`.
|
||||
See the documentation for `torch.class_type` for information.
|
||||
}];
|
||||
|
||||
let arguments = (ins);
|
||||
|
@ -81,7 +86,11 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
|||
let regions = (region SizedRegion<1>:$region);
|
||||
let verifier = "return ::verify(*this);";
|
||||
|
||||
let assemblyFormat = "$region attr-dict";
|
||||
let assemblyFormat = "$region attr-dict `:` type($result)";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
StringRef getClassName() { return getType().getClassName(); }
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
|
||||
|
@ -94,13 +103,13 @@ def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
|
|||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def Torch_AttrOp : Torch_Op<"attr", [
|
||||
def Torch_SlotOp : Torch_Op<"slot", [
|
||||
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
|
||||
let summary = "Define an attribute of a torch.nn.Module";
|
||||
let summary = "Define the value of a slot of a torch.nn.Module";
|
||||
let description = [{
|
||||
This op declaratively specifies that the parent torch.nn_module has an
|
||||
attribute `name` with value `value`, which is allowed to be an arbitrary
|
||||
Torch-compatible SSA value, including other torch.nn.Module's.
|
||||
This op specifies that the initial value of the slot `name` of the
|
||||
parent torch.nn_module should be `value`, which is allowed to be an
|
||||
arbitrary Torch-compatible SSA value, including other !torch.nn.Module's.
|
||||
}];
|
||||
|
||||
let arguments = (ins StrAttr:$name, AnyTorchType:$value);
|
||||
|
@ -111,13 +120,73 @@ def Torch_AttrOp : Torch_Op<"attr", [
|
|||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Modeling of TorchScript class types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Torch_ClassTypeOp : Torch_Op<"class_type", [
|
||||
Symbol,
|
||||
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::ClassTypeTerminatorOp">]> {
|
||||
let summary = "Constructs a torch.ClassType";
|
||||
let description = [{
|
||||
Declares a class type. Class types are the types used to describe
|
||||
TorchScript `torch.nn.Module`'s. The terminology "class type" is for
|
||||
consistency with TorchScript (a better name in our context might be
|
||||
"nn module subtype"). The `syn_name` of this op is the same string
|
||||
as in the `!torch.nn.Module<"...">` type.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// A simple empty torch.class_type, with corresponding torch.nn_module.
|
||||
torch.class_type @empty {}
|
||||
%submodule = torch.nn_module {} : !torch.nn.Module<"empty">
|
||||
|
||||
// A class type with many members.
|
||||
torch.class_type @test {
|
||||
torch.attr "b" : !basicpy.BoolType
|
||||
torch.attr "i" : i64
|
||||
torch.attr "f" : f64
|
||||
torch.attr "t" : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.attr "submodule" : !torch.nn.Module<"empty">
|
||||
torch.method "method", @f
|
||||
}
|
||||
torch.nn_module {
|
||||
// These must match the order and names in the `torch.class_type`.
|
||||
torch.slot "b", %bool_true : !basicpy.BoolType
|
||||
torch.slot "i", %num3_i64 : i64
|
||||
torch.slot "f", %num : f64
|
||||
torch.slot "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.slot "submodule", %submodule : !torch.nn.Module<"empty">
|
||||
} : !torch.nn.Module<"test">
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins SymbolNameAttr:$sym_name);
|
||||
let results = (outs);
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
let verifier = "return ::verify(*this);";
|
||||
|
||||
let assemblyFormat = "$sym_name $region attr-dict";
|
||||
}
|
||||
|
||||
def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator,
|
||||
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">]> {
|
||||
let summary = "Implicit terminator for torch.class_type";
|
||||
|
||||
let arguments = (ins);
|
||||
let results = (outs);
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def Torch_MethodOp : Torch_Op<"method", [
|
||||
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">,
|
||||
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">,
|
||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>
|
||||
]> {
|
||||
let summary = "Define a method of a torch.nn.Module";
|
||||
let summary = "Declare a method of a torch.class_type";
|
||||
let description = [{
|
||||
This op declaratively specifies that the parent torch.nn_module has a
|
||||
This op declaratively specifies that the parent torch.class_type has a
|
||||
method `name` which calls `function`. `function` is an unbound function.
|
||||
That is, it explicitly takes the torch.nn.Module as a parameter (no implicit
|
||||
"self" object).
|
||||
|
@ -131,6 +200,88 @@ def Torch_MethodOp : Torch_Op<"method", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AttrOp : Torch_Op<"attr", [
|
||||
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">
|
||||
]> {
|
||||
let summary = "Declare an attribute of a torch.class_type";
|
||||
let description = [{
|
||||
This op declaratively specifies that torch.nn.Module's of the parent
|
||||
torch.class_type must have an attribute `name` of type `type`.
|
||||
}];
|
||||
|
||||
let arguments = (ins StrAttr:$name, TypeAttr:$type);
|
||||
let results = (outs);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$name `:` $type attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Global slot ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO: Should these be in a separate dialect?
|
||||
// At this point, they are fairly specific to torch types, but their get/set
|
||||
// semantics follow Python.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
|
||||
Symbol,
|
||||
IsolatedFromAbove,
|
||||
]> {
|
||||
let summary = "A slot with global storage";
|
||||
let description = [{
|
||||
Represents a slot with global storage. The slot semantics are the same
|
||||
as Python's: getting or setting a slot is done by object identity.
|
||||
}];
|
||||
|
||||
let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$typeBound);
|
||||
let results = (outs);
|
||||
|
||||
|
||||
let assemblyFormat = [{
|
||||
$sym_name attr-dict `:` $typeBound
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// The name of the function, which, for semantic correctness, must be called
|
||||
// exactly once and this call must be done before any other calls into
|
||||
// the module.
|
||||
// TODO: Avoid load-bearing names.
|
||||
// We could replace this with an op that marks the function as initializer.
|
||||
static constexpr StringRef getGlobalSlotInitializerFuncName() {
|
||||
return "__torch_global_slot_initializer";
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_GlobalSlotGetOp : Torch_Op<"global_slot.get", []> {
|
||||
let summary = "Get the value stored in a torch.global_slot";
|
||||
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$slot
|
||||
);
|
||||
let results = (outs AnyTorchType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$slot attr-dict `:` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_GlobalSlotSetOp : Torch_Op<"global_slot.set", []> {
|
||||
let summary = "Set the value stored in a torch.global_slot";
|
||||
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$slot,
|
||||
AnyTorchType:$value
|
||||
);
|
||||
let results = (outs);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$slot `=` $value attr-dict `:` type($value)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TorchScript `prim::` ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -142,7 +293,7 @@ def Torch_PrimGetAttrOp : Torch_Op<"prim.GetAttr", []> {
|
|||
let results = (outs AnyTorchType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$receiver `[` $name `]` attr-dict `:` type($result)
|
||||
$receiver `[` $name `]` attr-dict `:` type($receiver) `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -157,7 +308,7 @@ def Torch_PrimSetAttrOp : Torch_Op<"prim.SetAttr", []> {
|
|||
let results = (outs);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$receiver `[` $name `]` `=` $value attr-dict `:` type($value)
|
||||
$receiver `[` $name `]` `=` $value attr-dict `:` type($receiver) `,` type($value)
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -172,7 +323,7 @@ def Torch_PrimCallMethodOp : Torch_Op<"prim.CallMethod", []> {
|
|||
let results = (outs AnyTorchType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$receiver `[` $name `]` `(` $operands `)` attr-dict `:` type($operands) `->` type($result)
|
||||
$receiver `[` $name `]` `(` $operands `)` attr-dict `:` type($receiver) `,` functional-type($operands, $result)
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -22,9 +22,62 @@ class Torch_Type<string name, string typeMnemonic> : TypeDef<Torch_Dialect, name
|
|||
def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
|
||||
let summary = "torch.nn.Module";
|
||||
let description = [{
|
||||
Represents an instance of a `torch.nn.Module` with the given `className`.
|
||||
}];
|
||||
let parameters = (ins StringRefParameter<"class name">:$className);
|
||||
|
||||
let printer = [{
|
||||
$_printer << "nn.Module<\"";
|
||||
llvm::printEscapedString(getImpl()->className, $_printer.getStream());
|
||||
$_printer << "\">";
|
||||
}];
|
||||
|
||||
let parser = [{
|
||||
if (parser.parseLess())
|
||||
return Type();
|
||||
StringRef className;
|
||||
if ($_parser.parseOptionalString(&className))
|
||||
return Type();
|
||||
if ($_parser.parseGreater())
|
||||
return Type();
|
||||
return get($_ctxt, className);
|
||||
}];
|
||||
}
|
||||
|
||||
// TODO: It feels like this should be something more general.
|
||||
// However, to do that, we need to agree on construction operations
|
||||
// and the valid MLIR representations of the "None" state.
|
||||
//
|
||||
// For now, we only need it as a stand-in type to allow importing
|
||||
// the `_is_full_backward_hook` optional bool type that Torch puts on
|
||||
// all classes.
|
||||
def Torch_OptionalType : Torch_Type<"Optional", "optional"> {
|
||||
let summary = "!torch.optional<T>";
|
||||
let description = [{
|
||||
}];
|
||||
let parameters = (ins "::mlir::Type":$containedType);
|
||||
|
||||
let printer = [{
|
||||
$_printer << "optional<" << getImpl()->containedType << ">";
|
||||
}];
|
||||
|
||||
let parser = [{
|
||||
if (parser.parseLess())
|
||||
return Type();
|
||||
Type containedType;
|
||||
if ($_parser.parseType(containedType))
|
||||
return Type();
|
||||
if ($_parser.parseGreater())
|
||||
return Type();
|
||||
return get($_ctxt, containedType);
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{
|
||||
return Base::get(containedType.getContext(), containedType);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type predicates
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
add_public_tablegen_target(NPCOMPTorchPassIncGen)
|
||||
|
||||
add_mlir_doc(Passes -gen-pass-doc NPCOMPTorchTransforms ./)
|
|
@ -0,0 +1,30 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
||||
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace Torch {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
|
||||
|
||||
} // namespace Torch
|
||||
|
||||
/// Registers all Torch transformation passes.
|
||||
void registerTorchPasses();
|
||||
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
|
@ -0,0 +1,65 @@
|
|||
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_TORCH_PASSES
|
||||
#define NPCOMP_TORCH_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
|
||||
let summary = "Converts TorchScript object graphs to a globalized form";
|
||||
let constructor = "mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass()";
|
||||
let description = [{
|
||||
This pass converts a subset of possible TorchScript modules into a
|
||||
more restrictive lower-level form that strips away the need to be
|
||||
concerned with instances of !torch.nn.Module<...> type. Specifically,
|
||||
the object graph is flattened into a set of discrete globals
|
||||
(`torch.global_slot`) that hold the program state.
|
||||
|
||||
The overarching goal is for a strict correspondence between the original
|
||||
`torch.nn.Module` (call it `root`) that the user `torch.jit.script`'ed, and
|
||||
the public interface of the resulting MLIR module. Specifically:
|
||||
- The call `root.encoder.forward(...)` in Python corresponds to invoking
|
||||
the `func @encoder.forward` on the resulting MLIR module.
|
||||
- The data member access `root.decoder.ids_to_strings_table` in Python
|
||||
corresponds to accessing the
|
||||
`torch.global_slot @decoder.ids_to_strings_table` on the resulting
|
||||
MLIR module.
|
||||
In effect, the entire MLIR module corresponds to an instance of the `root`
|
||||
object. This matches with the intuitive behavior desired for deployment:
|
||||
When the MLIR module (or, more likely, a compiled artifact derived from it)
|
||||
is loaded in a deployed environment, it is equivalent to recreating the
|
||||
original `root` object.
|
||||
|
||||
This pass performs a complete change of the externally visible calling
|
||||
convention of the MLIR module for a graph of objects and methods to a
|
||||
fixed set of globals and functions.
|
||||
|
||||
Of course, only a subset of programs can be transformed, and this pass fails
|
||||
with an error if the conditions are violated.
|
||||
|
||||
Specifically, the restrictions are:
|
||||
- There must be a unique torch.nn_module that is not the value of a slot
|
||||
of any other torch.nn_module
|
||||
- Rationale: Allows us to have a notion of a unique "root" op, which is
|
||||
used to define linkage. This also matches how TorchScript imports in
|
||||
practice (`torch.jit.script` imports a single root object).
|
||||
- There must be exactly one instance of each torch.class_type. Equivalently,
|
||||
Every torch.nn_module must have a distinct type.
|
||||
- Rationale: This guarantee precludes things like selecting between
|
||||
multiple modules dynamically at runtime, which would require indirecting
|
||||
between the separate storage of each instance.
|
||||
- All torch.nn_module's must be reachable by a unique path from the root
|
||||
- Rationale: Eliminates possibility of potentially exponential number of
|
||||
paths. Or worse, infinite number of paths when considering cyclic
|
||||
object graphs. Also as of Feb 2021, TorchScript won't import into
|
||||
this form (it has a bug related to the identity of submodules).
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // NPCOMP_TORCH_PASSES
|
|
@ -156,7 +156,21 @@ int npcompTypeIsANnModule(MlirType t) {
|
|||
return unwrap(t).isa<Torch::NnModuleType>();
|
||||
}
|
||||
|
||||
/** Gets the singleton torch.nn.Module type. */
|
||||
MlirType npcompNnModuleTypeGet(MlirContext context) {
|
||||
return wrap(Torch::NnModuleType::get(unwrap(context)));
|
||||
/** Gets the torch.nn.Module type of the specified class. */
|
||||
MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className) {
|
||||
return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.optional type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a !torch.optional<T> type */
|
||||
int npcompTypeIsAOptional(MlirType t) {
|
||||
return unwrap(t).isa<Torch::OptionalType>();
|
||||
}
|
||||
|
||||
/** Gets the !torch.optional<T> type with subtype T. */
|
||||
MlirType npcompOptionalTypeGet(MlirType containedType) {
|
||||
return wrap(Torch::OptionalType::get(unwrap(containedType)));
|
||||
}
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
@ -49,8 +50,23 @@ KernelMetadata KernelCallOp::getTorchKernelMetadata() {
|
|||
LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
auto func = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, function());
|
||||
if (!func)
|
||||
return emitError() << "'" << function()
|
||||
return emitError() << "'@" << function()
|
||||
<< "' does not reference a valid function";
|
||||
if (func.getVisibility() != SymbolTable::Visibility::Private)
|
||||
return emitError() << "'@" << function()
|
||||
<< "' must reference a private function";
|
||||
if (func.isDeclaration())
|
||||
return emitError() << "'@" << function()
|
||||
<< "' must reference a function that is defined (not "
|
||||
"merely declared)";
|
||||
auto expectedReceiverArgType = NnModuleType::get(
|
||||
getContext(), getOperation()->getParentOfType<ClassTypeOp>().getName());
|
||||
if (func.getType().getNumInputs() == 0 ||
|
||||
func.getType().getInput(0) != expectedReceiverArgType) {
|
||||
return emitError() << "the referenced function '" << function()
|
||||
<< "' must have a first argument of type "
|
||||
<< expectedReceiverArgType;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -60,8 +76,82 @@ LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|||
|
||||
static LogicalResult verify(NnModuleOp op) {
|
||||
for (Operation &child : *op.getBody())
|
||||
if (!isa<AttrOp, MethodOp, NnModuleTerminatorOp>(&child))
|
||||
return child.emitOpError() << "is not allowed inside `torch.nn_module`";
|
||||
if (!isa<SlotOp, NnModuleTerminatorOp>(&child))
|
||||
return child.emitOpError() << "is not allowed inside 'torch.nn_module'";
|
||||
return success();
|
||||
}
|
||||
|
||||
// PyTorch has a well-developed notion of subtyping.
|
||||
//
|
||||
// This is a restricted subset of it.
|
||||
//
|
||||
// TODO: Flesh this out.
|
||||
bool isValidSubtype(Type subtype, Type type) {
|
||||
if (subtype == type)
|
||||
return true;
|
||||
if (auto optional = type.dyn_cast<OptionalType>())
|
||||
return subtype == optional.getContainedType() ||
|
||||
subtype.isa<Basicpy::NoneType>();
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
auto classType =
|
||||
symbolTable.lookupNearestSymbolFrom<ClassTypeOp>(*this, getClassName());
|
||||
if (!classType)
|
||||
return emitError() << "'" << getClassName()
|
||||
<< "' does not reference a valid class type";
|
||||
|
||||
auto attrs = llvm::to_vector<6>(getBody()->getOps<SlotOp>());
|
||||
auto attrDefs = llvm::to_vector<6>(classType.getBody()->getOps<AttrOp>());
|
||||
if (attrs.size() != attrDefs.size())
|
||||
return emitError() << "number of 'torch.slot's in a 'torch.nn_module' must "
|
||||
"match number of 'torch.attr's in "
|
||||
"the corresponding 'torch.class_type'";
|
||||
for (int i = 0, e = attrs.size(); i != e; i++) {
|
||||
SlotOp attr = attrs[i];
|
||||
AttrOp attrDef = attrDefs[i];
|
||||
if (!isValidSubtype(attr.value().getType(), attrDef.type()) ||
|
||||
attr.name() != attrDef.name()) {
|
||||
return attr.emitOpError()
|
||||
.append("is expected to match type and name of '",
|
||||
attrDef.getOperation(), "'")
|
||||
.attachNote(attrDef.getLoc())
|
||||
.append("see torch.attr at corresponding index ", i, " here");
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ClassTypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(ClassTypeOp op) {
|
||||
llvm::StringMap<Operation *> namesToOps;
|
||||
for (Operation &child : op.getBody()->without_terminator()) {
|
||||
if (!isa<AttrOp, MethodOp>(&child))
|
||||
return child.emitOpError() << "is not allowed inside `torch.class_type`";
|
||||
StringRef name;
|
||||
if (auto attr = dyn_cast<AttrOp>(child))
|
||||
name = attr.name();
|
||||
else
|
||||
name = cast<MethodOp>(child).name();
|
||||
auto itAndWasInserted = namesToOps.insert({name, &child});
|
||||
auto it = itAndWasInserted.first;
|
||||
bool wasInserted = itAndWasInserted.second;
|
||||
if (!wasInserted) {
|
||||
auto diag = op.emitOpError().append(
|
||||
"has duplicate attr/method with name '", name, "'");
|
||||
diag.attachNote(it->second->getLoc())
|
||||
.append("see first conflicting attr/method here");
|
||||
diag.attachNote(child.getLoc())
|
||||
.append("see second conflicting attr/method here");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
add_npcomp_conversion_library(NPCOMPTorchPasses
|
||||
Passes.cpp
|
||||
GlobalizeObjectGraph.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms
|
||||
|
||||
DEPENDS
|
||||
NPCOMPTorchPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
NPCOMPTorchDialect
|
||||
NPCOMPBasicpyDialect
|
||||
)
|
|
@ -0,0 +1,340 @@
|
|||
//===- GlobalizeObjectGraph.cpp ----------------------------------*- C++-*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
|
||||
namespace {
|
||||
// See the pass documentation for `torch-globalize-object-graph`.
|
||||
class ObjectGraphGlobalizer {
|
||||
public:
|
||||
ObjectGraphGlobalizer(ModuleOp module);
|
||||
LogicalResult globalizeObjectGraph();
|
||||
|
||||
private:
|
||||
FailureOr<NnModuleOp> findRootNnModule();
|
||||
LogicalResult checkSingleInstanceOfEachClass();
|
||||
LogicalResult recursivelyTraverseClassType(ClassTypeOp classType);
|
||||
void createInitializerFunc();
|
||||
LogicalResult rewriteMethods();
|
||||
void removeObjectGraph();
|
||||
|
||||
ModuleOp module;
|
||||
SymbolTable symbolTable;
|
||||
OpBuilder globalBuilder;
|
||||
// The stack of attribute names we have traversed during our recursive
|
||||
// traversal of the class/object hierarchy.
|
||||
//
|
||||
// Linkage names are calculated based on the set of attribute names traversed
|
||||
// from the root class/module in the program.
|
||||
SmallVector<std::string> nameStack;
|
||||
|
||||
// Sometimes it is natural to want a map keyed on torch.attr ops or torch.slot
|
||||
// ops. However, usually it is better to keep a map keyed on an ClassTypeOp
|
||||
// + attr name since frequently that is all one has access to and it
|
||||
// would be tedious to scan the body of the ClassTypeOp for the torch.attr
|
||||
// with the corresponding name.
|
||||
using AttrOfClass =
|
||||
std::pair</*ClassTypeOp*/ Operation *, /*attr name*/ StringRef>;
|
||||
// The initial value associated with an attribute of a class.
|
||||
// Since we only allow a single instance of a class, this is equivalent to
|
||||
// the initial value of the unique slot corresponding to that attr.
|
||||
DenseMap<AttrOfClass, Value> slotInitialValues;
|
||||
// The inverse map of `slotInitialValues`.
|
||||
// Many attributes can have the same initial value, so the value type
|
||||
// is a vector.
|
||||
DenseMap<Value, std::vector<AttrOfClass>> slotInitialValuesInverseMap;
|
||||
|
||||
// The torch.global_slot corresponding to each torch.attr/torch.slot.
|
||||
DenseMap<AttrOfClass, GlobalSlotOp> globalSlotForAttr;
|
||||
// The linkage name (value) for the function with symbol name (key).
|
||||
DenseMap<StringRef, std::string> methodLinkageNames;
|
||||
|
||||
// The set of class types that have already been processed.
|
||||
// Used for diagnostics.
|
||||
// The map value is the original path from the root that we found it at.
|
||||
DenseMap</*ClassTypeOp*/ Operation *, std::string> seenClassTypes;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
ObjectGraphGlobalizer::ObjectGraphGlobalizer(ModuleOp module)
|
||||
: module(module), symbolTable(module),
|
||||
globalBuilder(module.getBodyRegion()) {}
|
||||
|
||||
LogicalResult ObjectGraphGlobalizer::globalizeObjectGraph() {
|
||||
// We require there to be a unique root !torch.nn.Module.
|
||||
FailureOr<NnModuleOp> maybeRootNnModule = findRootNnModule();
|
||||
if (failed(maybeRootNnModule))
|
||||
return failure();
|
||||
NnModuleOp rootNnModule = *maybeRootNnModule;
|
||||
if (!rootNnModule)
|
||||
return module.emitError()
|
||||
<< "module does not contain a root torch.nn_module";
|
||||
|
||||
// We require one instance of each class. That is, there is a single
|
||||
// torch.nn_module for each torch.class_type.
|
||||
if (failed(checkSingleInstanceOfEachClass()))
|
||||
return failure();
|
||||
|
||||
for (NnModuleOp nnModule : module.getOps<NnModuleOp>()) {
|
||||
auto classType = symbolTable.lookup<ClassTypeOp>(nnModule.getClassName());
|
||||
for (auto slot : nnModule.getOps<SlotOp>()) {
|
||||
AttrOfClass attrOfClass = {classType, slot.name()};
|
||||
slotInitialValues[attrOfClass] = slot.value();
|
||||
slotInitialValuesInverseMap[slot.value()].push_back(attrOfClass);
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively traverse the class hierarchy, globalizing slots and
|
||||
// tracking linkage names for methods.
|
||||
auto rootClassType =
|
||||
symbolTable.lookup<ClassTypeOp>(rootNnModule.getClassName());
|
||||
if (failed(recursivelyTraverseClassType(rootClassType)))
|
||||
return failure();
|
||||
|
||||
// Move all slot initial values into an initializer func.
|
||||
createInitializerFunc();
|
||||
|
||||
// Rewrite torch.prim.GetAttr/torch.prim.SetAttr/torch.prim.CallMethod.
|
||||
if (failed(rewriteMethods()))
|
||||
return failure();
|
||||
|
||||
// Now that all we have finished converting to the new form, remove
|
||||
// the original object graph.
|
||||
removeObjectGraph();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
FailureOr<NnModuleOp> ObjectGraphGlobalizer::findRootNnModule() {
|
||||
NnModuleOp rootNnModule;
|
||||
for (NnModuleOp op : module.getOps<NnModuleOp>()) {
|
||||
if (!op.use_empty())
|
||||
continue;
|
||||
if (rootNnModule) {
|
||||
op.emitError()
|
||||
.append("found more than one root module (module that is not a "
|
||||
"child of any other module)")
|
||||
.attachNote(rootNnModule.getLoc())
|
||||
.append("see other root module here");
|
||||
return failure();
|
||||
}
|
||||
rootNnModule = op;
|
||||
}
|
||||
return rootNnModule;
|
||||
}
|
||||
|
||||
LogicalResult ObjectGraphGlobalizer::checkSingleInstanceOfEachClass() {
|
||||
llvm::MapVector</*ClassTypeOp*/ Operation *, std::vector<NnModuleOp>>
|
||||
classInstances;
|
||||
for (NnModuleOp op : module.getOps<NnModuleOp>()) {
|
||||
auto classType = symbolTable.lookup<ClassTypeOp>(op.getClassName());
|
||||
classInstances[classType].push_back(op);
|
||||
}
|
||||
for (auto &p : classInstances) {
|
||||
ClassTypeOp classType = cast<ClassTypeOp>(p.first);
|
||||
ArrayRef<NnModuleOp> instances = p.second;
|
||||
if (instances.size() > 1) {
|
||||
// TODO: Improve this diagnostic based on user use cases.
|
||||
// This is a user-facing diagnostic that enforces key invariants to
|
||||
// our TorchScript subset.
|
||||
auto diag = classType.emitError(
|
||||
"class type has more than one instance: the current TorchScript "
|
||||
"supported subset only allows single instances");
|
||||
for (NnModuleOp instance : instances) {
|
||||
diag.attachNote(instance.getLoc()) << "see instance here";
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
ObjectGraphGlobalizer::recursivelyTraverseClassType(ClassTypeOp classType) {
|
||||
std::string pathToClassFromRoot = llvm::join(nameStack, ".");
|
||||
if (!seenClassTypes.insert({classType, pathToClassFromRoot}).second) {
|
||||
return classType.emitError()
|
||||
<< "reachable by multiple paths from root object: '<root>."
|
||||
<< seenClassTypes[classType] << "' and '<root>."
|
||||
<< pathToClassFromRoot << "'";
|
||||
}
|
||||
|
||||
// For each attr, create a global slot for it.
|
||||
for (auto attr : classType.getOps<AttrOp>()) {
|
||||
nameStack.push_back(attr.name().str());
|
||||
if (auto type = attr.type().dyn_cast<NnModuleType>()) {
|
||||
recursivelyTraverseClassType(
|
||||
symbolTable.lookup<ClassTypeOp>(type.getClassName()));
|
||||
} else {
|
||||
auto linkageName = llvm::join(nameStack, ".");
|
||||
auto globalSlot = globalBuilder.create<GlobalSlotOp>(
|
||||
attr->getLoc(), linkageName, TypeAttr::get(attr.type()));
|
||||
AttrOfClass attrOfClass = {classType, attr.name()};
|
||||
assert(globalSlotForAttr.find(attrOfClass) == globalSlotForAttr.end());
|
||||
globalSlotForAttr[attrOfClass] = globalSlot;
|
||||
}
|
||||
nameStack.pop_back();
|
||||
}
|
||||
// For each method, track the linkage name it will eventually have.
|
||||
for (auto method : classType.getOps<MethodOp>()) {
|
||||
nameStack.push_back(method.name().str());
|
||||
auto linkageName = llvm::join(nameStack, ".");
|
||||
nameStack.pop_back();
|
||||
if (!methodLinkageNames.insert({method.function(), linkageName}).second)
|
||||
method.emitError() << "unbound function shared by multiple methods";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void ObjectGraphGlobalizer::createInitializerFunc() {
|
||||
auto loc = module.getLoc();
|
||||
auto func = globalBuilder.create<FuncOp>(
|
||||
loc, GlobalSlotOp::getGlobalSlotInitializerFuncName(),
|
||||
globalBuilder.getFunctionType({}, {}));
|
||||
OpBuilder builder(func.getContext());
|
||||
Block *body = builder.createBlock(&func.getBody());
|
||||
|
||||
SmallVector<Operation *> opsToMove;
|
||||
for (Operation &op : llvm::make_early_inc_range(*module.getBody())) {
|
||||
if (isa<ClassTypeOp, NnModuleOp, GlobalSlotOp, FuncOp, ModuleTerminatorOp>(
|
||||
&op))
|
||||
continue;
|
||||
op.moveBefore(body, body->end());
|
||||
for (Value result : llvm::make_early_inc_range(op.getResults())) {
|
||||
auto it = slotInitialValuesInverseMap.find(result);
|
||||
if (it == slotInitialValuesInverseMap.end())
|
||||
continue;
|
||||
for (AttrOfClass attrOfClass : it->second) {
|
||||
GlobalSlotOp globalSlot = globalSlotForAttr[attrOfClass];
|
||||
OpBuilder::atBlockEnd(body).create<GlobalSlotSetOp>(
|
||||
globalSlot.getLoc(), globalSlot.sym_name(), result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
builder.create<ReturnOp>(loc);
|
||||
}
|
||||
|
||||
LogicalResult ObjectGraphGlobalizer::rewriteMethods() {
|
||||
DenseMap<AttrOfClass, StringRef> linkageNames;
|
||||
for (auto classType : module.getOps<ClassTypeOp>()) {
|
||||
for (auto method : classType.getOps<MethodOp>()) {
|
||||
auto it = methodLinkageNames.find(method.function());
|
||||
if (it == methodLinkageNames.end())
|
||||
continue;
|
||||
linkageNames[{classType, method.name()}] = it->second;
|
||||
}
|
||||
}
|
||||
// We only handle a small subset of ops that conform with the set of
|
||||
// assumptions that allow us to globalize the object graph. Anything that
|
||||
// tries to treat modules as bona-fide objects and not just namespaces
|
||||
// of methods with a single instance of the corresponding type just gets
|
||||
// arbitrarily tricky to rewrite. E.g. what if the user creates a list
|
||||
// of modules, or there is an scf.if selecting between modules, etc.
|
||||
auto rewriteOpWithNnModuleTypeOperand = [&](Operation *op) {
|
||||
if (auto primSetAttr = dyn_cast<PrimSetAttrOp>(op)) {
|
||||
auto classType = symbolTable.lookup<ClassTypeOp>(
|
||||
primSetAttr.receiver().getType().cast<NnModuleType>().getClassName());
|
||||
auto globalSlot = globalSlotForAttr[{classType, primSetAttr.name()}];
|
||||
OpBuilder(primSetAttr)
|
||||
.create<GlobalSlotSetOp>(primSetAttr.getLoc(), globalSlot.sym_name(),
|
||||
primSetAttr.value());
|
||||
primSetAttr.erase();
|
||||
}
|
||||
if (auto primGetAttr = dyn_cast<PrimGetAttrOp>(op)) {
|
||||
auto classType = symbolTable.lookup<ClassTypeOp>(
|
||||
primGetAttr.receiver().getType().cast<NnModuleType>().getClassName());
|
||||
auto globalSlot = globalSlotForAttr[{classType, primGetAttr.name()}];
|
||||
auto globalSlotGet = OpBuilder(primGetAttr)
|
||||
.create<GlobalSlotGetOp>(primGetAttr.getLoc(),
|
||||
primGetAttr.getType(),
|
||||
globalSlot.sym_name());
|
||||
primGetAttr.replaceAllUsesWith(globalSlotGet.getOperation());
|
||||
primGetAttr.erase();
|
||||
}
|
||||
if (auto primCallMethod = dyn_cast<PrimCallMethodOp>(op)) {
|
||||
auto classType = symbolTable.lookup<ClassTypeOp>(primCallMethod.receiver()
|
||||
.getType()
|
||||
.cast<NnModuleType>()
|
||||
.getClassName());
|
||||
StringRef linkageName = linkageNames[{classType, primCallMethod.name()}];
|
||||
auto call = OpBuilder(primCallMethod)
|
||||
.create<CallOp>(primCallMethod.getLoc(), linkageName,
|
||||
primCallMethod.getType(),
|
||||
primCallMethod.operands());
|
||||
primCallMethod.replaceAllUsesWith(call);
|
||||
primCallMethod.erase();
|
||||
}
|
||||
};
|
||||
for (auto classType : module.getOps<ClassTypeOp>()) {
|
||||
for (auto method : classType.getOps<MethodOp>()) {
|
||||
auto it = methodLinkageNames.find(method.function());
|
||||
if (it == methodLinkageNames.end())
|
||||
continue;
|
||||
FuncOp func = symbolTable.lookup<FuncOp>(method.function());
|
||||
func.setVisibility(SymbolTable::Visibility::Public);
|
||||
func.setName(it->second);
|
||||
func.walk(rewriteOpWithNnModuleTypeOperand);
|
||||
SmallVector<unsigned> argsToErase;
|
||||
for (auto arg : llvm::enumerate(func.getArguments())) {
|
||||
if (!arg.value().getType().isa<NnModuleType>())
|
||||
continue;
|
||||
if (!arg.value().use_empty()) {
|
||||
// TODO: Improve this based on real user use cases.
|
||||
// This is a diagnostic that users will hit if they do not conform to
|
||||
// the supported subset of TorchScript.
|
||||
auto diag = func.emitError().append(
|
||||
"func argument at index ", arg.index(),
|
||||
" has uses that were not able to be converted");
|
||||
for (Operation *user : arg.value().getUsers())
|
||||
diag.attachNote(user->getLoc()).append("see user here");
|
||||
return failure();
|
||||
}
|
||||
argsToErase.push_back(arg.index());
|
||||
}
|
||||
func.eraseArguments(argsToErase);
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void ObjectGraphGlobalizer::removeObjectGraph() {
|
||||
for (Operation &op : llvm::make_early_inc_range(*module.getBody())) {
|
||||
if (isa<ClassTypeOp, NnModuleOp>(op))
|
||||
op.erase();
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
class GlobalizeObjectGraphPass
|
||||
: public GlobalizeObjectGraphBase<GlobalizeObjectGraphPass> {
|
||||
void runOnOperation() override {
|
||||
if (failed(ObjectGraphGlobalizer(getOperation()).globalizeObjectGraph()))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass() {
|
||||
return std::make_unique<GlobalizeObjectGraphPass>();
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
//===- PassDetail.h - Pass details ------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
||||
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace Torch {
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc"
|
||||
|
||||
} // namespace Torch
|
||||
} // namespace NPCOMP
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
|
@ -0,0 +1,20 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc"
|
||||
} // end namespace
|
||||
|
||||
void mlir::NPCOMP::registerTorchPasses() { ::registerPasses(); }
|
|
@ -21,6 +21,7 @@
|
|||
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||
#include "npcomp/Dialect/TCP/Transforms/Passes.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "npcomp/Typing/Transforms/Passes.h"
|
||||
|
||||
#include "npcomp/Conversion/Passes.h"
|
||||
|
@ -47,5 +48,6 @@ void mlir::NPCOMP::registerAllPasses() {
|
|||
mlir::NPCOMP::registerNumpyPasses();
|
||||
mlir::NPCOMP::registerTCFPasses();
|
||||
mlir::NPCOMP::registerTCPPasses();
|
||||
mlir::NPCOMP::registerTorchPasses();
|
||||
mlir::NPCOMP::registerTypingPasses();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
|
||||
|
||||
torch.class_type @c1 {}
|
||||
torch.class_type @c2 {}
|
||||
|
||||
// expected-note @+1 {{see other root module here}}
|
||||
torch.nn_module {} : !torch.nn.Module<"c1">
|
||||
// expected-error @+1 {{found more than one root module (module that is not a child of any other module)}}
|
||||
torch.nn_module {} : !torch.nn.Module<"c2">
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{class type has more than one instance: the current TorchScript supported subset only allows single instances}}
|
||||
torch.class_type @child {}
|
||||
torch.class_type @parent {
|
||||
torch.attr "m1" : !torch.nn.Module<"child">
|
||||
torch.attr "m2" : !torch.nn.Module<"child">
|
||||
}
|
||||
|
||||
// expected-note @+1 {{see instance here}}
|
||||
%0 = torch.nn_module {} : !torch.nn.Module<"child">
|
||||
// expected-note @+1 {{see instance here}}
|
||||
%1 = torch.nn_module {} : !torch.nn.Module<"child">
|
||||
|
||||
%root = torch.nn_module {
|
||||
torch.slot "m1", %0 : !torch.nn.Module<"child">
|
||||
torch.slot "m2", %1 : !torch.nn.Module<"child">
|
||||
} : !torch.nn.Module<"parent">
|
||||
|
||||
// -----
|
||||
|
||||
torch.class_type @c {
|
||||
torch.method "f", @f
|
||||
}
|
||||
// expected-error @+1 {{func argument at index 1 has uses that were not able to be converted}}
|
||||
func private @f(%arg0: !torch.nn.Module<"c">, %arg1: !torch.nn.Module<"c">) {
|
||||
// expected-note @+1 {{see user here}}
|
||||
%0 = basicpy.build_list %arg1 : (!torch.nn.Module<"c">) -> !basicpy.ListType
|
||||
return
|
||||
}
|
||||
|
||||
torch.nn_module {} : !torch.nn.Module<"c">
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{reachable by multiple paths from root object: '<root>.m' and '<root>.m2'}}
|
||||
torch.class_type @child {
|
||||
torch.attr "float" : f64
|
||||
}
|
||||
torch.class_type @parent {
|
||||
torch.attr "m" : !torch.nn.Module<"child">
|
||||
torch.attr "m2" : !torch.nn.Module<"child">
|
||||
|
||||
}
|
||||
|
||||
%c42 = std.constant 42.0 : f64
|
||||
%child = torch.nn_module {
|
||||
torch.slot "float", %c42 : f64
|
||||
} : !torch.nn.Module<"child">
|
||||
%parent = torch.nn_module {
|
||||
torch.slot "m", %child : !torch.nn.Module<"child">
|
||||
torch.slot "m2", %child : !torch.nn.Module<"child">
|
||||
} : !torch.nn.Module<"parent">
|
|
@ -0,0 +1,39 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
torch.class_type @c {
|
||||
torch.attr "float" : f64
|
||||
torch.method "test_get", @test_get
|
||||
torch.method "test_set", @test_set
|
||||
torch.method "test_call", @test_call
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_get() -> f64 {
|
||||
// CHECK: %[[V:.*]] = torch.global_slot.get @float : f64
|
||||
// CHECK: return %[[V]] : f64
|
||||
func private @test_get(%arg0: !torch.nn.Module<"c">) -> f64 {
|
||||
%0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> f64
|
||||
return %0 : f64
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_set(
|
||||
// CHECK-SAME: %[[A:.*]]: f64) {
|
||||
// CHECK: torch.global_slot.set @float = %[[A]] : f64
|
||||
// CHECK: return
|
||||
func private @test_set(%arg0: !torch.nn.Module<"c">, %arg1: f64) {
|
||||
torch.prim.SetAttr %arg0["float"] = %arg1 : !torch.nn.Module<"c">, f64
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_call(
|
||||
// CHECK-SAME: %[[A:.*]]: f64) -> f64 {
|
||||
// CHECK: %[[V:.*]] = call @test_call(%[[A]]) : (f64) -> f64
|
||||
// CHECK: return %[[V]] : f64
|
||||
func private @test_call(%arg0: !torch.nn.Module<"c">, %arg1: f64) -> f64 {
|
||||
%0 = torch.prim.CallMethod %arg0["test_call"] (%arg1) : !torch.nn.Module<"c">, (f64) -> f64
|
||||
return %0 : f64
|
||||
}
|
||||
|
||||
%c42 = std.constant 42.0 : f64
|
||||
torch.nn_module {
|
||||
torch.slot "float", %c42 : f64
|
||||
} : !torch.nn.Module<"c">
|
|
@ -0,0 +1,25 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
// Check that linkage names consist of the dotted path from the root.
|
||||
|
||||
// CHECK-LABEL: torch.global_slot @m.float : f64
|
||||
|
||||
// CHECK-LABEL: func @__torch_global_slot_initializer() {
|
||||
// CHECK: %[[C42:.*]] = constant 4.200000e+01 : f64
|
||||
// CHECK: torch.global_slot.set @m.float = %[[C42]] : f64
|
||||
// CHECK: return
|
||||
|
||||
torch.class_type @child {
|
||||
torch.attr "float" : f64
|
||||
}
|
||||
torch.class_type @parent {
|
||||
torch.attr "m" : !torch.nn.Module<"child">
|
||||
}
|
||||
|
||||
%c42 = std.constant 42.0 : f64
|
||||
%child = torch.nn_module {
|
||||
torch.slot "float", %c42 : f64
|
||||
} : !torch.nn.Module<"child">
|
||||
%parent = torch.nn_module {
|
||||
torch.slot "m", %child : !torch.nn.Module<"child">
|
||||
} : !torch.nn.Module<"parent">
|
|
@ -0,0 +1,63 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
// Basic case.
|
||||
|
||||
// CHECK-LABEL: torch.global_slot @b : !basicpy.BoolType
|
||||
// CHECK: torch.global_slot @i : i64
|
||||
// CHECK: torch.global_slot @f : f64
|
||||
// CHECK: torch.global_slot @a : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
|
||||
// CHECK-LABEL: func @__torch_global_slot_initializer() {
|
||||
// CHECK: %[[CB:.*]] = basicpy.bool_constant true
|
||||
// CHECK: torch.global_slot.set @b = %[[CB]] : !basicpy.BoolType
|
||||
// CHECK: %[[CI:.*]] = basicpy.numeric_constant 3 : i64
|
||||
// CHECK: torch.global_slot.set @i = %[[CI]] : i64
|
||||
// CHECK: %[[CF:.*]] = basicpy.numeric_constant 4.250000e+01 : f64
|
||||
// CHECK: torch.global_slot.set @f = %[[CF]] : f64
|
||||
// CHECK: %[[C:.*]] = constant dense<1.000000e+00> : tensor<1xf32>
|
||||
// CHECK: %[[CA:.*]] = numpy.create_array_from_tensor %[[C]] : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK: torch.global_slot.set @a = %[[CA]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
// CHECK: return
|
||||
|
||||
torch.class_type @c {
|
||||
torch.attr "b" : !basicpy.BoolType
|
||||
torch.attr "i" : i64
|
||||
torch.attr "f" : f64
|
||||
torch.attr "a" : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
}
|
||||
|
||||
%bool_true = basicpy.bool_constant true
|
||||
%i = basicpy.numeric_constant 3 : i64
|
||||
%f = basicpy.numeric_constant 4.250000e+01 : f64
|
||||
%cst = constant dense<1.000000e+00> : tensor<1xf32>
|
||||
%a = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.nn_module {
|
||||
torch.slot "b", %bool_true : !basicpy.BoolType
|
||||
torch.slot "i", %i : i64
|
||||
torch.slot "f", %f : f64
|
||||
torch.slot "a", %a : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
} : !torch.nn.Module<"c">
|
||||
|
||||
// -----
|
||||
|
||||
// Same SSA value used as initializer for multiple slots.
|
||||
|
||||
// CHECK-LABEL: torch.global_slot @b1 : !basicpy.BoolType
|
||||
// CHECK-LABEL: torch.global_slot @b2 : !basicpy.BoolType
|
||||
// CHECK-LABEL: func @__torch_global_slot_initializer() {
|
||||
// CHECK: %[[TRUE:.*]] = basicpy.bool_constant true
|
||||
// CHECK: torch.global_slot.set @b1 = %[[TRUE]] : !basicpy.BoolType
|
||||
// CHECK: torch.global_slot.set @b2 = %[[TRUE]] : !basicpy.BoolType
|
||||
// CHECK: return
|
||||
// CHECK: }
|
||||
|
||||
torch.class_type @c {
|
||||
torch.attr "b1" : !basicpy.BoolType
|
||||
torch.attr "b2" : !basicpy.BoolType
|
||||
}
|
||||
|
||||
%bool_true = basicpy.bool_constant true
|
||||
torch.nn_module {
|
||||
torch.slot "b1", %bool_true : !basicpy.BoolType
|
||||
torch.slot "b2", %bool_true : !basicpy.BoolType
|
||||
} : !torch.nn.Module<"c">
|
|
@ -2,14 +2,98 @@
|
|||
|
||||
// -----
|
||||
|
||||
torch.nn_module {
|
||||
// expected-error @+1 {{'func' op is not allowed inside `torch.nn_module`}}
|
||||
torch.class_type @c {}
|
||||
%0 = torch.nn_module {
|
||||
// expected-error @+1 {{'func' op is not allowed inside 'torch.nn_module'}}
|
||||
func @f()
|
||||
} : !torch.nn.Module<"c">
|
||||
|
||||
// -----
|
||||
|
||||
torch.class_type @c {}
|
||||
%c0 = constant 0 : i64
|
||||
// expected-error @+1 {{number of 'torch.slot's in a 'torch.nn_module' must match number of 'torch.attr's in the corresponding 'torch.class_type'}}
|
||||
%0 = torch.nn_module {
|
||||
torch.slot "f", %c0 : i64
|
||||
} : !torch.nn.Module<"c">
|
||||
|
||||
// -----
|
||||
|
||||
torch.class_type @c {
|
||||
// expected-note @+1 {{see torch.attr at corresponding index 0 here}}
|
||||
torch.attr "g" : i64
|
||||
}
|
||||
%c0 = constant 0 : i64
|
||||
%0 = torch.nn_module {
|
||||
// expected-error @+1 {{'torch.slot' op is expected to match type and name of 'torch.attr "g" : i64'}}
|
||||
torch.slot "f", %c0 : i64
|
||||
} : !torch.nn.Module<"c">
|
||||
|
||||
// -----
|
||||
|
||||
torch.class_type @c {
|
||||
// expected-error @+1 {{'func' op is not allowed inside `torch.class_type`}}
|
||||
func @f()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
torch.nn_module {
|
||||
// expected-error @+1 {{'invalidSym' does not reference a valid function}}
|
||||
// expected-error @+1 {{has duplicate attr/method with name 'a'}}
|
||||
torch.class_type @c {
|
||||
// expected-note @+1 {{see first conflicting attr/method here}}
|
||||
torch.attr "a" : i64
|
||||
// expected-note @+1 {{see second conflicting attr/method here}}
|
||||
torch.attr "a" : i64
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
torch.class_type @c {
|
||||
// expected-error @+1 {{'@invalidSym' does not reference a valid function}}
|
||||
torch.method "f", @invalidSym
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
torch.class_type @c {
|
||||
// expected-error @+1 {{'@f' must reference a private function}}
|
||||
torch.method "f", @f
|
||||
}
|
||||
|
||||
func @f(%arg0: !torch.nn.Module<"c">) {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
torch.class_type @c {
|
||||
// expected-error @+1 {{'@f' must reference a function that is defined (not merely declared)}}
|
||||
torch.method "f", @f
|
||||
}
|
||||
|
||||
func private @f(%arg0: !torch.nn.Module<"c">)
|
||||
|
||||
// -----
|
||||
|
||||
func private @f() {
|
||||
return
|
||||
}
|
||||
torch.class_type @c {
|
||||
// expected-error @+1 {{the referenced function 'f' must have a first argument of type '!torch.nn.Module<"c">'}}
|
||||
torch.method "f", @f
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @f(!torch.nn.Module<"other_c">) {
|
||||
return
|
||||
}
|
||||
torch.class_type @c {
|
||||
// expected-error @+1 {{the referenced function 'f' must have a first argument of type '!torch.nn.Module<"c">'}}
|
||||
torch.method "f", @f
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{'a' does not reference a valid class type}}
|
||||
%m = torch.nn_module {} : !torch.nn.Module<"a">
|
||||
|
|
|
@ -14,16 +14,28 @@ func @kernel_call(%arg0 : si32, %arg1 : tensor<3x4xf32>) -> tensor<*xf32> {
|
|||
%num = basicpy.numeric_constant 4.250000e+01 : f64
|
||||
%cst = constant dense<1.000000e+00> : tensor<1xf32>
|
||||
%array = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
func @f(%arg0: !torch.nn.Module) {
|
||||
%none = basicpy.singleton : !basicpy.NoneType
|
||||
func private @f(%arg0: !torch.nn.Module<"test">) {
|
||||
return
|
||||
}
|
||||
%submodule = torch.nn_module {}
|
||||
|
||||
torch.nn_module {
|
||||
torch.attr "b", %bool_true : !basicpy.BoolType
|
||||
torch.attr "i", %num3_i64 : i64
|
||||
torch.attr "f", %num : f64
|
||||
torch.attr "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.attr "submodule", %submodule : !torch.nn.Module
|
||||
torch.class_type @empty {}
|
||||
%submodule = torch.nn_module {} : !torch.nn.Module<"empty">
|
||||
|
||||
torch.class_type @test {
|
||||
torch.attr "b" : !basicpy.BoolType
|
||||
torch.attr "i" : i64
|
||||
torch.attr "f" : f64
|
||||
torch.attr "t" : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.attr "submodule" : !torch.nn.Module<"empty">
|
||||
torch.attr "ob" : !torch.optional<!basicpy.BoolType>
|
||||
torch.method "method", @f
|
||||
}
|
||||
torch.nn_module {
|
||||
torch.slot "b", %bool_true : !basicpy.BoolType
|
||||
torch.slot "i", %num3_i64 : i64
|
||||
torch.slot "f", %num : f64
|
||||
torch.slot "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
torch.slot "submodule", %submodule : !torch.nn.Module<"empty">
|
||||
torch.slot "ob", %none : !basicpy.NoneType
|
||||
} : !torch.nn.Module<"test">
|
||||
|
|
Loading…
Reference in New Issue