diff --git a/frontends/pytorch/csrc/builder/ivalue_importer.cpp b/frontends/pytorch/csrc/builder/ivalue_importer.cpp index 80fe5f7ee..44d41de0e 100644 --- a/frontends/pytorch/csrc/builder/ivalue_importer.cpp +++ b/frontends/pytorch/csrc/builder/ivalue_importer.cpp @@ -101,7 +101,8 @@ public: private: MlirValue rawImportIValue(c10::IValue value); MlirValue importModule(torch::jit::Module jitModule); - void importMethod(torch::jit::Function *function, MlirBlock nnModuleBody); + void importMethod(torch::jit::Function *function, MlirBlock classTypeBody); + void importClassType(c10::ClassType *classType); MlirBlock importBlock; MlirContext context; @@ -111,6 +112,12 @@ private: std::unordered_map valueMap; // Used to detect potentially aliasing tensors. std::unordered_set seenStorageImpls; + // The set of ClassType's that have already been imported. + // + // ClassType's are referenced via their `classType->name()->qualifiedName()` + // string (as an MLIR symbol name) so we don't need to keep a map associating + // them with the MlirOperation that they import into. + std::unordered_set<c10::ClassType *> classTypes; // The stack of attribute names we have traversed to reach the current IValue. // Used for diagnostics. std::vector<std::string> attributeNameStack; @@ -128,16 +135,25 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) { // TODO: Can we do better? MlirLocation loc = mlirLocationUnknownGet(context); - MlirOperation nnModule = - createMlirOperation("torch.nn_module", loc, - npcompNnModuleTypeGet(context), mlirRegionCreate()); + c10::optional<c10::QualifiedName> maybeName = currentModule.type()->name(); + if (!maybeName) { + throw std::invalid_argument("cannot import unnamed module"); + } + std::string moduleTypeName = maybeName->qualifiedName(); + + // Ensure the class type has been imported. + importClassType(currentModule.type().get()); + + MlirOperation nnModule = createMlirOperation( + "torch.nn_module", loc, + npcompNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)), + mlirRegionCreate()); MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0); mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr)); MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion); if (!rootModuleName.has_value()) { - c10::optional<c10::QualifiedName> maybeName = currentModule.type()->name(); - rootModuleName = maybeName ? maybeName->qualifiedName() : "unnamed module"; + rootModuleName = moduleTypeName; } const std::vector<c10::IValue> &slots = currentModule._ivalue()->slots(); @@ -151,7 +167,7 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) { MlirValue slotValue = importIValue(slots[i]); // TODO: Is it necessary to track whether an attribute is a "parameter"?
createMlirOperationAtEnd( - nnModuleBody, "torch.attr", loc, slotValue, + nnModuleBody, "torch.slot", loc, slotValue, toMlirNamedAttribute( "name", mlirStringAttrGet( context, toMlirStringRef(classAttribute.getName())))); @@ -162,10 +178,6 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) { rootModuleName = c10::nullopt; } - for (torch::jit::Function *function : currentModule.type()->methods()) { - importMethod(function, nnModuleBody); - } - createMlirOperationAtEnd(nnModuleBody, "torch.nn_module_terminator", loc); mlirBlockInsertOwnedOperationBefore( importBlock, mlirBlockGetTerminator(importBlock), nnModule); @@ -262,7 +274,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { } void IValueImporter::importMethod(torch::jit::Function *function, - MlirBlock nnModuleBody) { + MlirBlock classTypeBody) { // We make an effort for the func op's symbol name to be useful for debugging, // but still clearly non-load-bearing. std::string symName = @@ -275,13 +287,50 @@ void IValueImporter::importMethod(torch::jit::Function *function, mlirBlockInsertOwnedOperationBefore( importBlock, mlirBlockGetTerminator(importBlock), func); createMlirOperationAtEnd( - nnModuleBody, "torch.method", mlirLocationUnknownGet(context), + classTypeBody, "torch.method", mlirLocationUnknownGet(context), toMlirNamedAttribute( "name", mlirStringAttrGet(context, toMlirStringRef(function->name()))), toMlirNamedAttribute("function", mlirFlatSymbolRefAttrGet( context, toMlirStringRef(symName)))); } + +void IValueImporter::importClassType(c10::ClassType *classType) { + if (!classTypes.insert(classType).second) { + return; + } + + // TODO: Can we do better? + MlirLocation loc = mlirLocationUnknownGet(context); + + MlirOperation op = createMlirOperationAtEnd( + importBlock, "torch.class_type", loc, mlirRegionCreate(), + toMlirNamedAttribute( + "sym_name", + mlirStringAttrGet( + context, toMlirStringRef(classType->name()->qualifiedName())))); + MlirRegion region = mlirOperationGetRegion(op, 0); + mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr)); + MlirBlock classTypeBody = mlirRegionGetFirstBlock(region); + + for (const c10::ClassAttribute &classAttribute : classType->getAttributes()) { + createMlirOperationAtEnd( + classTypeBody, "torch.attr", loc, + toMlirNamedAttribute( + "name", mlirStringAttrGet( + context, toMlirStringRef(classAttribute.getName()))), + toMlirNamedAttribute("type", + mlirTypeAttrGet(typeMapper.mapFromTorchType( + loc, classAttribute.getType())))); + } + + for (torch::jit::Function *function : classType->methods()) { + importMethod(function, classTypeBody); + } + + createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc); +} + void torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context) { // When debugging module importing, it can be useful to dump as so: diff --git a/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp b/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp index 8167df8ee..a72656803 100644 --- a/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp +++ b/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp @@ -106,11 +106,19 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc, return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType); } case TypeKind::ClassType: { - return npcompNnModuleTypeGet(context); + auto maybeName = torchType->cast<c10::ClassType>()->name(); + return npcompNnModuleTypeGet( + context, toMlirStringRef(maybeName ?
maybeName->qualifiedName() : "unnamed class")); } case TypeKind::FloatType: { return mlirF64TypeGet(context); } + case TypeKind::OptionalType: { + return npcompOptionalTypeGet( + mapFromTorchType( + loc, torchType->cast<c10::OptionalType>()->getElementType())); + } case TypeKind::IntType: { return mlirIntegerTypeGet(context, 64); } @@ -120,6 +128,10 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc, case TypeKind::BoolType: { return npcompBoolTypeGet(context); } + case TypeKind::ListType: { + // TODO: Don't lose the element type information. + return npcompListTypeGet(context); + } default: { std::stringstream message; message << "unable to map Torch type " << *torchType << " to MLIR type"; diff --git a/frontends/pytorch/test/module_import/list.py b/frontends/pytorch/test/module_import/list.py index 57fa488c8..f78bec4ff 100644 --- a/frontends/pytorch/test/module_import/list.py +++ b/frontends/pytorch/test/module_import/list.py @@ -15,13 +15,16 @@ class TestModule(torch.nn.Module): def __init__(self): super().__init__() self.l = [1, 2] - +# CHECK: torch.class_type @[[CLASSTYPE:.*]] { +# TODO: Don't lose element type. +# CHECK: torch.attr "l" : !basicpy.ListType +# CHECK: } # CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64 # CHECK: %[[N2:.*]] = basicpy.numeric_constant 2 : i64 # CHECK: %[[LIST:.*]] = basicpy.build_list %[[N1]], %[[N2]] : (i64, i64) -> !basicpy.ListType # CHECK: torch.nn_module { -# CHECK: torch.attr "l", %[[LIST]] : !basicpy.ListType -# CHECK: } +# CHECK: torch.slot "l", %[[LIST]] : !basicpy.ListType +# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]"> test_module = TestModule() diff --git a/frontends/pytorch/test/module_import/methods-debuggable-ir.py b/frontends/pytorch/test/module_import/methods-debuggable-ir.py index a39041325..8817a5afc 100644 --- a/frontends/pytorch/test/module_import/methods-debuggable-ir.py +++ b/frontends/pytorch/test/module_import/methods-debuggable-ir.py @@ -32,8 +32,8 @@ class TestModule(torch.nn.Module): # the case that the name is `__main__` Torch replaces it with `__torch__` to # avoid collisions. -# CHECK: func private @__npcomp_priv_fn.__torch__.Submodule.forward -# CHECK: func private @__npcomp_priv_fn.__torch__.TestModule.forward +# CHECK-DAG: func private @__npcomp_priv_fn.__torch__.TestModule.forward +# CHECK-DAG: func private @__npcomp_priv_fn.__torch__.Submodule.forward test_module = TestModule() diff --git a/frontends/pytorch/test/module_import/methods.py b/frontends/pytorch/test/module_import/methods.py index 32dfc219d..f070ecd92 100644 --- a/frontends/pytorch/test/module_import/methods.py +++ b/frontends/pytorch/test/module_import/methods.py @@ -19,17 +19,22 @@ class TestModule(torch.nn.Module): # The symbol name of the function is NOT load-bearing and cannot be relied upon.
+# CHECK-LABEL: torch.class_type +# CHECK-SAME: @[[CLASSTYPE:.*]] { +# CHECK: torch.method "forward", @[[SYMNAME:.*]] +# CHECK: } + + # CHECK-LABEL: func private -# CHECK-SAME: @[[SYMNAME:.*]]( -# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module, +# CHECK-SAME: @[[SYMNAME]]( +# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"[[CLASSTYPE]]">, # CHECK-SAME: %[[X:.*]]: !numpy.ndarray<*:!numpy.any_dtype>, # CHECK-SAME: %[[Y:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> { # CHECK: %[[RET:.*]] = torch.kernel_call "aten::mul" %[[X]], %[[Y]] # CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype> # CHECK: %[[ROOT:.*]] = torch.nn_module { -# CHECK: torch.method "forward", @[[SYMNAME]] -# CHECK: } +# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]"> test_module = TestModule() diff --git a/frontends/pytorch/test/module_import/object-identity-torch-bug.py b/frontends/pytorch/test/module_import/object-identity-torch-bug.py index f6f405c60..fd310a49e 100644 --- a/frontends/pytorch/test/module_import/object-identity-torch-bug.py +++ b/frontends/pytorch/test/module_import/object-identity-torch-bug.py @@ -24,8 +24,8 @@ class TestModule(torch.nn.Module): # CHECK: %[[L2:.*]] = basicpy.build_list # CHECK: %[[L1:.*]] = basicpy.build_list # CHECK: torch.nn_module { - # CHECK: torch.attr "l2", %[[L2]] - # CHECK: torch.attr "l1", %[[L1]] + # CHECK: torch.slot "l2", %[[L2]] + # CHECK: torch.slot "l1", %[[L1]] self.l2 = self.l1 = [1] # This can be uncommented when the graph importer supports it. diff --git a/frontends/pytorch/test/module_import/object-identity.py b/frontends/pytorch/test/module_import/object-identity.py index d5448a27d..dd930c0ab 100644 --- a/frontends/pytorch/test/module_import/object-identity.py +++ b/frontends/pytorch/test/module_import/object-identity.py @@ -16,8 +16,8 @@ class TestModule(torch.nn.Module): super().__init__() # CHECK: %[[A:.*]] = numpy.create_array_from_tensor # CHECK: torch.nn_module { - # CHECK: torch.attr "t1", %[[A]] - # CHECK: torch.attr "t2", %[[A]] + # CHECK: torch.slot "t1", %[[A]] + # CHECK: torch.slot "t2", %[[A]] self.t1 = self.t2 = torch.tensor([10., 20.]) diff --git a/frontends/pytorch/test/module_import/prim.py b/frontends/pytorch/test/module_import/prim.py index 29ae5f402..578805238 100644 --- a/frontends/pytorch/test/module_import/prim.py +++ b/frontends/pytorch/test/module_import/prim.py @@ -18,7 +18,7 @@ class TestModule(torch.nn.Module): self.t2 = torch.ones(1) # CHECK-LABEL: func{{.*}}TestModule.forward{{.*}}( - # CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module) -> !basicpy.NoneType { + # CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">) -> !basicpy.NoneType { def forward(self): # CHECK: %[[T2:.*]] = torch.prim.GetAttr %[[SELF]]["t2"] # CHECK: torch.prim.SetAttr %[[SELF]]["t1"] = %[[T2]] @@ -26,7 +26,7 @@ class TestModule(torch.nn.Module): # CHECK: torch.prim.CallMethod %[[SELF]]["callee"] (%{{.*}}, %{{.*}}) self.callee(self.t1, self.t2) # CHECK-LABEL: func{{.*}}TestModule.callee{{.*}}( - # CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module, + # CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">, # CHECK-SAME: %[[X:.*]]: !numpy.ndarray<*:!numpy.any_dtype>, # CHECK-SAME: %[[Y:.*]]: !numpy.ndarray<*:!numpy.any_dtype> def callee(self, x, y): diff --git a/frontends/pytorch/test/module_import/primitives.py b/frontends/pytorch/test/module_import/primitives.py index b401086ad..f505787da 100644 --- a/frontends/pytorch/test/module_import/primitives.py +++ b/frontends/pytorch/test/module_import/primitives.py @@ -17,14 +17,20 @@ class 
TestModule(torch.nn.Module): self.i = 3 self.f = 42.5 +# CHECK: torch.class_type @[[CLASSTYPE:.*]] { +# CHECK: torch.attr "training" : !basicpy.BoolType +# CHECK: torch.attr "i" : i64 +# CHECK: torch.attr "f" : f64 +# CHECK: } # CHECK: %[[TRUE:.*]] = basicpy.bool_constant true # CHECK: %[[N3:.*]] = basicpy.numeric_constant 3 : i64 # CHECK: %[[N42:.*]] = basicpy.numeric_constant 4.250000e+01 : f64 # CHECK: %[[MODULE:.*]] = torch.nn_module { # Note: for some reason, Torch always adds a "training" property to all modules. -# CHECK: torch.attr "training", %[[TRUE]] : !basicpy.BoolType -# CHECK: torch.attr "i", %[[N3]] : i64 -# CHECK: torch.attr "f", %[[N42]] : f64 +# CHECK: torch.slot "training", %[[TRUE]] : !basicpy.BoolType +# CHECK: torch.slot "i", %[[N3]] : i64 +# CHECK: torch.slot "f", %[[N42]] : f64 +# CHECK: } : !torch.nn.Module<"[[CLASSTYPE:.*]]"> test_module = TestModule() diff --git a/frontends/pytorch/test/module_import/submodules-select.py b/frontends/pytorch/test/module_import/submodules-select.py index c6ff7ee1c..088a3c388 100644 --- a/frontends/pytorch/test/module_import/submodules-select.py +++ b/frontends/pytorch/test/module_import/submodules-select.py @@ -15,6 +15,8 @@ class Submodule(torch.nn.Module): def __init__(self, n): super().__init__() self.n = n + def forward(self): + return self.n class TestModule(torch.nn.Module): def __init__(self): @@ -27,9 +29,9 @@ class TestModule(torch.nn.Module): # Modules with the same class can be selected between. # CHECK: %[[MOD:.*]] = scf.if s = self.s1 if b else self.s2 - # CHECK: %[[N:.*]] = torch.prim.GetAttr %5["n"] - # CHECK: return %[[N]] - return s.n + # CHECK: %[[N:.*]] = torch.prim.CallMethod %[[MOD]]["forward"] () + # CHECK: return %[[N]] + return s.forward() test_module = TestModule() diff --git a/frontends/pytorch/test/module_import/submodules.py b/frontends/pytorch/test/module_import/submodules.py index 3a57c085d..94cfbf9ae 100644 --- a/frontends/pytorch/test/module_import/submodules.py +++ b/frontends/pytorch/test/module_import/submodules.py @@ -26,20 +26,20 @@ class TestModule(torch.nn.Module): # CHECK: %[[N0:.*]] = basicpy.numeric_constant 0 : i64 # CHECK: %[[S0:.*]] = torch.nn_module { -# CHECK: torch.attr "training", %[[T]] : !basicpy.BoolType -# CHECK: torch.attr "n", %[[N0]] : i64 +# CHECK: torch.slot "training", %[[T]] : !basicpy.BoolType +# CHECK: torch.slot "n", %[[N0]] : i64 # CHECK: } # CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64 # CHECK: %[[S1:.*]] = torch.nn_module { -# CHECK: torch.attr "training", %[[T]] : !basicpy.BoolType -# CHECK: torch.attr "n", %[[N1]] : i64 +# CHECK: torch.slot "training", %[[T]] : !basicpy.BoolType +# CHECK: torch.slot "n", %[[N1]] : i64 # CHECK: } # CHECK: %[[ROOT:.*]] = torch.nn_module { -# CHECK: torch.attr "training", %[[T]] : !basicpy.BoolType -# CHECK: torch.attr "s0", %[[S0]] : !torch.nn.Module -# CHECK: torch.attr "s1", %[[S1]] : !torch.nn.Module +# CHECK: torch.slot "training", %[[T]] : !basicpy.BoolType +# CHECK: torch.slot "s0", %[[S0]] : !torch.nn.Module +# CHECK: torch.slot "s1", %[[S1]] : !torch.nn.Module # CHECK: } diff --git a/frontends/pytorch/test/module_import/tensors.py b/frontends/pytorch/test/module_import/tensors.py index af12f4c3d..37a1a3a30 100644 --- a/frontends/pytorch/test/module_import/tensors.py +++ b/frontends/pytorch/test/module_import/tensors.py @@ -23,8 +23,8 @@ class TestModule(torch.nn.Module): # CHECK: %[[CT:.*]] = constant dense<1.000000e+00> : tensor<1xf32> # CHECK: %[[T:.*]] = numpy.create_array_from_tensor %[[CT]] : (tensor<1xf32>) -> 
!numpy.ndarray<*:!numpy.any_dtype> # CHECK: %[[ROOT:.*]] = torch.nn_module { -# CHECK: torch.attr "p", %[[P]] : !numpy.ndarray<*:!numpy.any_dtype> -# CHECK: torch.attr "t", %[[T]] : !numpy.ndarray<*:!numpy.any_dtype> +# CHECK: torch.slot "p", %[[P]] : !numpy.ndarray<*:!numpy.any_dtype> +# CHECK: torch.slot "t", %[[T]] : !numpy.ndarray<*:!numpy.any_dtype> # CHECK: } diff --git a/include/npcomp-c/Types.h b/include/npcomp-c/Types.h index 0a8b679da..e194a71be 100644 --- a/include/npcomp-c/Types.h +++ b/include/npcomp-c/Types.h @@ -124,8 +124,18 @@ MlirType npcompTupleTypeGet(MlirContext context); /** Checks whether the given type is a torch.nn.Module type */ int npcompTypeIsANnModule(MlirType t); -/** Gets the singleton torch.nn.Module type. */ -MlirType npcompNnModuleTypeGet(MlirContext context); +/** Gets the !torch.nn.Module type of the specified class. */ +MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className); + +/*============================================================================*/ +/* torch.optional type. */ +/*============================================================================*/ + +/** Checks whether the given type is a !torch.optional type */ +int npcompTypeIsAOptional(MlirType t); + +/** Gets the !torch.optional type with subtype T. */ +MlirType npcompOptionalTypeGet(MlirType containedType); #ifdef __cplusplus } diff --git a/include/npcomp/Dialect/Torch/CMakeLists.txt b/include/npcomp/Dialect/Torch/CMakeLists.txt index f33061b2d..9f57627c3 100644 --- a/include/npcomp/Dialect/Torch/CMakeLists.txt +++ b/include/npcomp/Dialect/Torch/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/include/npcomp/Dialect/Torch/IR/TorchOps.td b/include/npcomp/Dialect/Torch/IR/TorchOps.td index d6386f8ea..96ba2c3a7 100644 --- a/include/npcomp/Dialect/Torch/IR/TorchOps.td +++ b/include/npcomp/Dialect/Torch/IR/TorchOps.td @@ -49,10 +49,11 @@ def Torch_KernelCallOp : Torch_Op<"kernel_call", [ } //===----------------------------------------------------------------------===// -// TorchScript modeling ops. +// TorchScript `torch.nn.Module` object instantiation ops. //===----------------------------------------------------------------------===// def Torch_NnModuleOp : Torch_Op<"nn_module", [ + DeclareOpInterfaceMethods<SymbolUserOpInterface>, SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::NnModuleTerminatorOp">]> { let summary = "Constructs a torch.nn.Module"; let description = [{ @@ -65,15 +66,19 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [ Example: ```mlir - %2 = torch.nn_module { - torch.attr "b", %bool_true : !basicpy.BoolType - torch.attr "i", %num3_i64 : i64 - torch.attr "f", %num : f64 - torch.attr "t", %0 : !numpy.ndarray<*:!numpy.any_dtype> - torch.attr "submodule", %1 : !torch.nn.Module - torch.method "method", @f - } + %2 = torch.nn_module { + torch.slot "b", %bool_true : !basicpy.BoolType + torch.slot "i", %num3_i64 : i64 + torch.slot "f", %num : f64 + torch.slot "t", %0 : !numpy.ndarray<*:!numpy.any_dtype> + torch.slot "submodule", %1 : !torch.nn.Module + } : !torch.nn.Module<"my_class_name"> ``` + + This op is tightly coupled to the `torch.class_type` op named in the + `!torch.nn.Module<"my_class_name">` type. Each slot must match precisely + with the corresponding `torch.attr` in the `torch.class_type`. + See the documentation for `torch.class_type` for information.
}]; let arguments = (ins); @@ -81,7 +86,11 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [ let regions = (region SizedRegion<1>:$region); let verifier = "return ::verify(*this);"; - let assemblyFormat = "$region attr-dict"; + let assemblyFormat = "$region attr-dict `:` type($result)"; + + let extraClassDeclaration = [{ + StringRef getClassName() { return getType().cast<NnModuleType>().getClassName(); } + }]; } def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator, @@ -94,13 +103,13 @@ def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator, let assemblyFormat = "attr-dict"; } -def Torch_AttrOp : Torch_Op<"attr", [ +def Torch_SlotOp : Torch_Op<"slot", [ HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> { - let summary = "Define an attribute of a torch.nn.Module"; + let summary = "Define the value of a slot of a torch.nn.Module"; let description = [{ - This op declaratively specifies that the parent torch.nn_module has an - attribute `name` with value `value`, which is allowed to be an arbitrary - Torch-compatible SSA value, including other torch.nn.Module's. + This op specifies that the initial value of the slot `name` of the + parent torch.nn_module should be `value`, which is allowed to be an + arbitrary Torch-compatible SSA value, including other !torch.nn.Module's. }]; let arguments = (ins StrAttr:$name, AnyTorchType:$value); @@ -111,13 +120,73 @@ def Torch_AttrOp : Torch_Op<"attr", [ }]; } +//===----------------------------------------------------------------------===// +// Modeling of TorchScript class types +//===----------------------------------------------------------------------===// + +def Torch_ClassTypeOp : Torch_Op<"class_type", [ + Symbol, + SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::ClassTypeTerminatorOp">]> { + let summary = "Constructs a torch.ClassType"; + let description = [{ + Declares a class type. Class types are the types used to describe + TorchScript `torch.nn.Module`'s. The terminology "class type" is for + consistency with TorchScript (a better name in our context might be + "nn module subtype"). The `sym_name` of this op is the same string + as in the `!torch.nn.Module<"...">` type. + + Example: + + ```mlir + // A simple empty torch.class_type, with corresponding torch.nn_module. + torch.class_type @empty {} + %submodule = torch.nn_module {} : !torch.nn.Module<"empty"> + + // A class type with many members. + torch.class_type @test { + torch.attr "b" : !basicpy.BoolType + torch.attr "i" : i64 + torch.attr "f" : f64 + torch.attr "t" : !numpy.ndarray<*:!numpy.any_dtype> + torch.attr "submodule" : !torch.nn.Module<"empty"> + torch.method "method", @f + } + torch.nn_module { + // These must match the order and names in the `torch.class_type`.
+ torch.slot "b", %bool_true : !basicpy.BoolType + torch.slot "i", %num3_i64 : i64 + torch.slot "f", %num : f64 + torch.slot "t", %array : !numpy.ndarray<*:!numpy.any_dtype> + torch.slot "submodule", %submodule : !torch.nn.Module<"empty"> + } : !torch.nn.Module<"test"> + ``` + }]; + + let arguments = (ins SymbolNameAttr:$sym_name); + let results = (outs); + let regions = (region SizedRegion<1>:$region); + let verifier = "return ::verify(*this);"; + + let assemblyFormat = "$sym_name $region attr-dict"; +} + +def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator, + HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">]> { + let summary = "Implicit terminator for torch.class_type"; + + let arguments = (ins); + let results = (outs); + + let assemblyFormat = "attr-dict"; +} + def Torch_MethodOp : Torch_Op<"method", [ - HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">, + HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">, DeclareOpInterfaceMethods ]> { - let summary = "Define a method of a torch.nn.Module"; + let summary = "Declare a method of a torch.class_type"; let description = [{ - This op declaratively specifies that the parent torch.nn_module has a + This op declaratively specifies that the parent torch.class_type has a method `name` which calls `function`. `function` is an unbound function. That is, it explicitly takes the torch.nn.Module as a parameter (no implicit "self" object). @@ -131,6 +200,88 @@ def Torch_MethodOp : Torch_Op<"method", [ }]; } +def Torch_AttrOp : Torch_Op<"attr", [ + HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp"> + ]> { + let summary = "Declare an attribute of a torch.class_type"; + let description = [{ + This op declaratively specifies that torch.nn.Module's of the parent + torch.class_type must have an attribute `name` of type `type`. + }]; + + let arguments = (ins StrAttr:$name, TypeAttr:$type); + let results = (outs); + + let assemblyFormat = [{ + $name `:` $type attr-dict + }]; +} + +//===----------------------------------------------------------------------===// +// Global slot ops +//===----------------------------------------------------------------------===// +// TODO: Should these be in a separate dialect? +// At this point, they are fairly specific to torch types, but their get/set +// semantics follow Python. +//===----------------------------------------------------------------------===// + +def Torch_GlobalSlotOp : Torch_Op<"global_slot", [ + Symbol, + IsolatedFromAbove, + ]> { + let summary = "A slot with global storage"; + let description = [{ + Represents a slot with global storage. The slot semantics are the same + as Python's: getting or setting a slot is done by object identity. + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$typeBound); + let results = (outs); + + + let assemblyFormat = [{ + $sym_name attr-dict `:` $typeBound + }]; + + let extraClassDeclaration = [{ + // The name of the function, which, for semantic correctness, must be called + // exactly once and this call must be done before any other calls into + // the module. + // TODO: Avoid load-bearing names. + // We could replace this with an op that marks the function as initializer. 
+ static constexpr StringRef getGlobalSlotInitializerFuncName() { + return "__torch_global_slot_initializer"; + } + }]; +} + +def Torch_GlobalSlotGetOp : Torch_Op<"global_slot.get", []> { + let summary = "Get the value stored in a torch.global_slot"; + + let arguments = (ins + FlatSymbolRefAttr:$slot + ); + let results = (outs AnyTorchType:$result); + + let assemblyFormat = [{ + $slot attr-dict `:` type($result) + }]; +} + +def Torch_GlobalSlotSetOp : Torch_Op<"global_slot.set", []> { + let summary = "Set the value stored in a torch.global_slot"; + + let arguments = (ins + FlatSymbolRefAttr:$slot, + AnyTorchType:$value + ); + let results = (outs); + + let assemblyFormat = [{ + $slot `=` $value attr-dict `:` type($value) + }]; +} + //===----------------------------------------------------------------------===// // TorchScript `prim::` ops. //===----------------------------------------------------------------------===// @@ -142,7 +293,7 @@ def Torch_PrimGetAttrOp : Torch_Op<"prim.GetAttr", []> { let results = (outs AnyTorchType:$result); let assemblyFormat = [{ - $receiver `[` $name `]` attr-dict `:` type($result) + $receiver `[` $name `]` attr-dict `:` type($receiver) `->` type($result) }]; } @@ -157,7 +308,7 @@ def Torch_PrimSetAttrOp : Torch_Op<"prim.SetAttr", []> { let results = (outs); let assemblyFormat = [{ - $receiver `[` $name `]` `=` $value attr-dict `:` type($value) + $receiver `[` $name `]` `=` $value attr-dict `:` type($receiver) `,` type($value) }]; } @@ -172,7 +323,7 @@ def Torch_PrimCallMethodOp : Torch_Op<"prim.CallMethod", []> { let results = (outs AnyTorchType:$result); let assemblyFormat = [{ - $receiver `[` $name `]` `(` $operands `)` attr-dict `:` type($operands) `->` type($result) + $receiver `[` $name `]` `(` $operands `)` attr-dict `:` type($receiver) `,` functional-type($operands, $result) }]; } diff --git a/include/npcomp/Dialect/Torch/IR/TorchTypes.td b/include/npcomp/Dialect/Torch/IR/TorchTypes.td index 0a72d2644..9dab6fc41 100644 --- a/include/npcomp/Dialect/Torch/IR/TorchTypes.td +++ b/include/npcomp/Dialect/Torch/IR/TorchTypes.td @@ -22,9 +22,62 @@ class Torch_Type : TypeDef { let summary = "torch.nn.Module"; let description = [{ + Represents an instance of a `torch.nn.Module` with the given `className`. + }]; + let parameters = (ins StringRefParameter<"class name">:$className); + + let printer = [{ + $_printer << "nn.Module<\""; + llvm::printEscapedString(getImpl()->className, $_printer.getStream()); + $_printer << "\">"; + }]; + + let parser = [{ + if (parser.parseLess()) + return Type(); + StringRef className; + if ($_parser.parseOptionalString(&className)) + return Type(); + if ($_parser.parseGreater()) + return Type(); + return get($_ctxt, className); }]; } +// TODO: It feels like this should be something more general. +// However, to do that, we need to agree on construction operations +// and the valid MLIR representations of the "None" state. +// +// For now, we only need it as a stand-in type to allow importing +// the `_is_full_backward_hook` optional bool type that Torch puts on +// all classes. 
+def Torch_OptionalType : Torch_Type<"Optional", "optional"> { + let summary = "!torch.optional"; + let description = [{ + }]; + let parameters = (ins "::mlir::Type":$containedType); + + let printer = [{ + $_printer << "optional<" << getImpl()->containedType << ">"; + }]; + + let parser = [{ + if (parser.parseLess()) + return Type(); + Type containedType; + if ($_parser.parseType(containedType)) + return Type(); + if ($_parser.parseGreater()) + return Type(); + return get($_ctxt, containedType); + }]; + + let builders = [ + TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{ + return Base::get(containedType.getContext(), containedType); + }]> + ]; +} //===----------------------------------------------------------------------===// // Type predicates diff --git a/include/npcomp/Dialect/Torch/Transforms/CMakeLists.txt b/include/npcomp/Dialect/Torch/Transforms/CMakeLists.txt new file mode 100644 index 000000000..a55778059 --- /dev/null +++ b/include/npcomp/Dialect/Torch/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(NPCOMPTorchPassIncGen) + +add_mlir_doc(Passes -gen-pass-doc NPCOMPTorchTransforms ./) diff --git a/include/npcomp/Dialect/Torch/Transforms/Passes.h b/include/npcomp/Dialect/Torch/Transforms/Passes.h new file mode 100644 index 000000000..12016e5fa --- /dev/null +++ b/include/npcomp/Dialect/Torch/Transforms/Passes.h @@ -0,0 +1,30 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H +#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H + +#include "mlir/Pass/Pass.h" + +#include + +namespace mlir { +namespace NPCOMP { +namespace Torch { + +std::unique_ptr> createGlobalizeObjectGraphPass(); + +} // namespace Torch + +/// Registers all Torch transformation passes. +void registerTorchPasses(); + +} // namespace NPCOMP +} // namespace mlir + +#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H diff --git a/include/npcomp/Dialect/Torch/Transforms/Passes.td b/include/npcomp/Dialect/Torch/Transforms/Passes.td new file mode 100644 index 000000000..ab427fd37 --- /dev/null +++ b/include/npcomp/Dialect/Torch/Transforms/Passes.td @@ -0,0 +1,65 @@ +//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_TORCH_PASSES +#define NPCOMP_TORCH_PASSES + +include "mlir/Pass/PassBase.td" + +def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> { + let summary = "Converts TorchScript object graphs to a globalized form"; + let constructor = "mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass()"; + let description = [{ + This pass converts a subset of possible TorchScript modules into a + more restrictive lower-level form that strips away the need to be + concerned with instances of !torch.nn.Module<...> type. 
Specifically, + the object graph is flattened into a set of discrete globals + (`torch.global_slot`) that hold the program state. + + The overarching goal is for a strict correspondence between the original + `torch.nn.Module` (call it `root`) that the user `torch.jit.script`'ed, and + the public interface of the resulting MLIR module. Specifically: + - The call `root.encoder.forward(...)` in Python corresponds to invoking + the `func @encoder.forward` on the resulting MLIR module. + - The data member access `root.decoder.ids_to_strings_table` in Python + corresponds to accessing the + `torch.global_slot @decoder.ids_to_strings_table` on the resulting + MLIR module. + In effect, the entire MLIR module corresponds to an instance of the `root` + object. This matches with the intuitive behavior desired for deployment: + When the MLIR module (or, more likely, a compiled artifact derived from it) + is loaded in a deployed environment, it is equivalent to recreating the + original `root` object. + + This pass performs a complete change of the externally visible calling + convention of the MLIR module for a graph of objects and methods to a + fixed set of globals and functions. + + Of course, only a subset of programs can be transformed, and this pass fails + with an error if the conditions are violated. + + Specifically, the restrictions are: + - There must be a unique torch.nn_module that is not the value of a slot + of any other torch.nn_module + - Rationale: Allows us to have a notion of a unique "root" op, which is + used to define linkage. This also matches how TorchScript imports in + practice (`torch.jit.script` imports a single root object). + - There must be exactly one instance of each torch.class_type. Equivalently, + Every torch.nn_module must have a distinct type. + - Rationale: This guarantee precludes things like selecting between + multiple modules dynamically at runtime, which would require indirecting + between the separate storage of each instance. + - All torch.nn_module's must be reachable by a unique path from the root + - Rationale: Eliminates possibility of potentially exponential number of + paths. Or worse, infinite number of paths when considering cyclic + object graphs. Also as of Feb 2021, TorchScript won't import into + this form (it has a bug related to the identity of submodules). + }]; +} + +#endif // NPCOMP_TORCH_PASSES diff --git a/lib/CAPI/Types.cpp b/lib/CAPI/Types.cpp index b24660a96..6aa626619 100644 --- a/lib/CAPI/Types.cpp +++ b/lib/CAPI/Types.cpp @@ -156,7 +156,21 @@ int npcompTypeIsANnModule(MlirType t) { return unwrap(t).isa(); } -/** Gets the singleton torch.nn.Module type. */ -MlirType npcompNnModuleTypeGet(MlirContext context) { - return wrap(Torch::NnModuleType::get(unwrap(context))); +/** Gets the torch.nn.Module type of the specified class. */ +MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className) { + return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className))); +} + +/*============================================================================*/ +/* torch.optional type. */ +/*============================================================================*/ + +/** Checks whether the given type is a !torch.optional type */ +int npcompTypeIsAOptional(MlirType t) { + return unwrap(t).isa(); +} + +/** Gets the !torch.optional type with subtype T. 
*/ +MlirType npcompOptionalTypeGet(MlirType containedType) { + return wrap(Torch::OptionalType::get(unwrap(containedType))); } diff --git a/lib/Dialect/Torch/CMakeLists.txt b/lib/Dialect/Torch/CMakeLists.txt index f33061b2d..9f57627c3 100644 --- a/lib/Dialect/Torch/CMakeLists.txt +++ b/lib/Dialect/Torch/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/lib/Dialect/Torch/IR/TorchDialect.cpp b/lib/Dialect/Torch/IR/TorchDialect.cpp index 89c903898..08e93b570 100644 --- a/lib/Dialect/Torch/IR/TorchDialect.cpp +++ b/lib/Dialect/Torch/IR/TorchDialect.cpp @@ -10,6 +10,7 @@ #include "mlir/IR/DialectImplementation.h" #include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "npcomp/Dialect/Torch/IR/TorchTypes.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 017b3b66b..92a749de1 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -14,6 +14,7 @@ #include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h" #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" #include "npcomp/Dialect/Numpy/IR/NumpyOps.h" +#include "llvm/ADT/StringMap.h" using namespace mlir; using namespace mlir::NPCOMP; @@ -49,8 +50,23 @@ KernelMetadata KernelCallOp::getTorchKernelMetadata() { LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) { auto func = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, function()); if (!func) - return emitError() << "'" << function() + return emitError() << "'@" << function() << "' does not reference a valid function"; + if (func.getVisibility() != SymbolTable::Visibility::Private) + return emitError() << "'@" << function() + << "' must reference a private function"; + if (func.isDeclaration()) + return emitError() << "'@" << function() + << "' must reference a function that is defined (not " + "merely declared)"; + auto expectedReceiverArgType = NnModuleType::get( + getContext(), getOperation()->getParentOfType<ClassTypeOp>().getName()); + if (func.getType().getNumInputs() == 0 || + func.getType().getInput(0) != expectedReceiverArgType) { + return emitError() << "the referenced function '" << function() + << "' must have a first argument of type " + << expectedReceiverArgType; + } return success(); } @@ -60,8 +76,82 @@ LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) { static LogicalResult verify(NnModuleOp op) { for (Operation &child : *op.getBody()) - if (!isa<AttrOp, MethodOp, NnModuleTerminatorOp>(&child)) - return child.emitOpError() << "is not allowed inside `torch.nn_module`"; + if (!isa<SlotOp, NnModuleTerminatorOp>(&child)) + return child.emitOpError() << "is not allowed inside 'torch.nn_module'"; + return success(); +} + +// PyTorch has a well-developed notion of subtyping. +// +// This is a restricted subset of it. +// +// TODO: Flesh this out.
+bool isValidSubtype(Type subtype, Type type) { + if (subtype == type) + return true; + if (auto optional = type.dyn_cast<OptionalType>()) + return subtype == optional.getContainedType() || + subtype.isa<Basicpy::NoneType>(); + return false; +} + +LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto classType = + symbolTable.lookupNearestSymbolFrom<ClassTypeOp>(*this, getClassName()); + if (!classType) + return emitError() << "'" << getClassName() + << "' does not reference a valid class type"; + + auto attrs = llvm::to_vector<6>(getBody()->getOps<SlotOp>()); + auto attrDefs = llvm::to_vector<6>(classType.getBody()->getOps<AttrOp>()); + if (attrs.size() != attrDefs.size()) + return emitError() << "number of 'torch.slot's in a 'torch.nn_module' must " + "match number of 'torch.attr's in " + "the corresponding 'torch.class_type'"; + for (int i = 0, e = attrs.size(); i != e; i++) { + SlotOp attr = attrs[i]; + AttrOp attrDef = attrDefs[i]; + if (!isValidSubtype(attr.value().getType(), attrDef.type()) || + attr.name() != attrDef.name()) { + return attr.emitOpError() + .append("is expected to match type and name of '", + attrDef.getOperation(), "'") + .attachNote(attrDef.getLoc()) + .append("see torch.attr at corresponding index ", i, " here"); + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// + +// ClassTypeOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(ClassTypeOp op) { + llvm::StringMap<Operation *> namesToOps; + for (Operation &child : op.getBody()->without_terminator()) { + if (!isa<AttrOp, MethodOp>(&child)) + return child.emitOpError() << "is not allowed inside `torch.class_type`"; + StringRef name; + if (auto attr = dyn_cast<AttrOp>(child)) + name = attr.name(); + else + name = cast<MethodOp>(child).name(); + auto itAndWasInserted = namesToOps.insert({name, &child}); + auto it = itAndWasInserted.first; + bool wasInserted = itAndWasInserted.second; + if (!wasInserted) { + auto diag = op.emitOpError().append( + "has duplicate attr/method with name '", name, "'"); + diag.attachNote(it->second->getLoc()) + .append("see first conflicting attr/method here"); + diag.attachNote(child.getLoc()) + .append("see second conflicting attr/method here"); + return failure(); + } + } + return success(); } diff --git a/lib/Dialect/Torch/Transforms/CMakeLists.txt b/lib/Dialect/Torch/Transforms/CMakeLists.txt new file mode 100644 index 000000000..b49e56794 --- /dev/null +++ b/lib/Dialect/Torch/Transforms/CMakeLists.txt @@ -0,0 +1,19 @@ +add_npcomp_conversion_library(NPCOMPTorchPasses + Passes.cpp + GlobalizeObjectGraph.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms + + DEPENDS + NPCOMPTorchPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + NPCOMPTorchDialect + NPCOMPBasicpyDialect +) diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp new file mode 100644 index 000000000..6ee787386 --- /dev/null +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -0,0 +1,340 @@ +//===- GlobalizeObjectGraph.cpp ----------------------------------*- C++-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" +#include "npcomp/Dialect/Torch/IR/TorchDialect.h" +#include "npcomp/Dialect/Torch/IR/TorchOps.h" +#include "npcomp/Dialect/Torch/Transforms/Passes.h" +#include "llvm/ADT/MapVector.h" + +using namespace mlir; +using namespace mlir::NPCOMP; +using namespace mlir::NPCOMP::Torch; + +namespace { +// See the pass documentation for `torch-globalize-object-graph`. +class ObjectGraphGlobalizer { +public: + ObjectGraphGlobalizer(ModuleOp module); + LogicalResult globalizeObjectGraph(); + +private: + FailureOr findRootNnModule(); + LogicalResult checkSingleInstanceOfEachClass(); + LogicalResult recursivelyTraverseClassType(ClassTypeOp classType); + void createInitializerFunc(); + LogicalResult rewriteMethods(); + void removeObjectGraph(); + + ModuleOp module; + SymbolTable symbolTable; + OpBuilder globalBuilder; + // The stack of attribute names we have traversed during our recursive + // traversal of the class/object hierarchy. + // + // Linkage names are calculated based on the set of attribute names traversed + // from the root class/module in the program. + SmallVector nameStack; + + // Sometimes it is natural to want a map keyed on torch.attr ops or torch.slot + // ops. However, usually it is better to keep a map keyed on an ClassTypeOp + // + attr name since frequently that is all one has access to and it + // would be tedious to scan the body of the ClassTypeOp for the torch.attr + // with the corresponding name. + using AttrOfClass = + std::pair; + // The initial value associated with an attribute of a class. + // Since we only allow a single instance of a class, this is equivalent to + // the initial value of the unique slot corresponding to that attr. + DenseMap slotInitialValues; + // The inverse map of `slotInitialValues`. + // Many attributes can have the same initial value, so the value type + // is a vector. + DenseMap> slotInitialValuesInverseMap; + + // The torch.global_slot corresponding to each torch.attr/torch.slot. + DenseMap globalSlotForAttr; + // The linkage name (value) for the function with symbol name (key). + DenseMap methodLinkageNames; + + // The set of class types that have already been processed. + // Used for diagnostics. + // The map value is the original path from the root that we found it at. + DenseMap seenClassTypes; +}; +} // namespace + +ObjectGraphGlobalizer::ObjectGraphGlobalizer(ModuleOp module) + : module(module), symbolTable(module), + globalBuilder(module.getBodyRegion()) {} + +LogicalResult ObjectGraphGlobalizer::globalizeObjectGraph() { + // We require there to be a unique root !torch.nn.Module. + FailureOr maybeRootNnModule = findRootNnModule(); + if (failed(maybeRootNnModule)) + return failure(); + NnModuleOp rootNnModule = *maybeRootNnModule; + if (!rootNnModule) + return module.emitError() + << "module does not contain a root torch.nn_module"; + + // We require one instance of each class. That is, there is a single + // torch.nn_module for each torch.class_type. 
+ if (failed(checkSingleInstanceOfEachClass())) + return failure(); + + for (NnModuleOp nnModule : module.getOps()) { + auto classType = symbolTable.lookup(nnModule.getClassName()); + for (auto slot : nnModule.getOps()) { + AttrOfClass attrOfClass = {classType, slot.name()}; + slotInitialValues[attrOfClass] = slot.value(); + slotInitialValuesInverseMap[slot.value()].push_back(attrOfClass); + } + } + + // Recursively traverse the class hierarchy, globalizing slots and + // tracking linkage names for methods. + auto rootClassType = + symbolTable.lookup(rootNnModule.getClassName()); + if (failed(recursivelyTraverseClassType(rootClassType))) + return failure(); + + // Move all slot initial values into an initializer func. + createInitializerFunc(); + + // Rewrite torch.prim.GetAttr/torch.prim.SetAttr/torch.prim.CallMethod. + if (failed(rewriteMethods())) + return failure(); + + // Now that all we have finished converting to the new form, remove + // the original object graph. + removeObjectGraph(); + + return success(); +} + +FailureOr ObjectGraphGlobalizer::findRootNnModule() { + NnModuleOp rootNnModule; + for (NnModuleOp op : module.getOps()) { + if (!op.use_empty()) + continue; + if (rootNnModule) { + op.emitError() + .append("found more than one root module (module that is not a " + "child of any other module)") + .attachNote(rootNnModule.getLoc()) + .append("see other root module here"); + return failure(); + } + rootNnModule = op; + } + return rootNnModule; +} + +LogicalResult ObjectGraphGlobalizer::checkSingleInstanceOfEachClass() { + llvm::MapVector> + classInstances; + for (NnModuleOp op : module.getOps()) { + auto classType = symbolTable.lookup(op.getClassName()); + classInstances[classType].push_back(op); + } + for (auto &p : classInstances) { + ClassTypeOp classType = cast(p.first); + ArrayRef instances = p.second; + if (instances.size() > 1) { + // TODO: Improve this diagnostic based on user use cases. + // This is a user-facing diagnostic that enforces key invariants to + // our TorchScript subset. + auto diag = classType.emitError( + "class type has more than one instance: the current TorchScript " + "supported subset only allows single instances"); + for (NnModuleOp instance : instances) { + diag.attachNote(instance.getLoc()) << "see instance here"; + } + return failure(); + } + } + return success(); +} + +LogicalResult +ObjectGraphGlobalizer::recursivelyTraverseClassType(ClassTypeOp classType) { + std::string pathToClassFromRoot = llvm::join(nameStack, "."); + if (!seenClassTypes.insert({classType, pathToClassFromRoot}).second) { + return classType.emitError() + << "reachable by multiple paths from root object: '." + << seenClassTypes[classType] << "' and '." + << pathToClassFromRoot << "'"; + } + + // For each attr, create a global slot for it. + for (auto attr : classType.getOps()) { + nameStack.push_back(attr.name().str()); + if (auto type = attr.type().dyn_cast()) { + recursivelyTraverseClassType( + symbolTable.lookup(type.getClassName())); + } else { + auto linkageName = llvm::join(nameStack, "."); + auto globalSlot = globalBuilder.create( + attr->getLoc(), linkageName, TypeAttr::get(attr.type())); + AttrOfClass attrOfClass = {classType, attr.name()}; + assert(globalSlotForAttr.find(attrOfClass) == globalSlotForAttr.end()); + globalSlotForAttr[attrOfClass] = globalSlot; + } + nameStack.pop_back(); + } + // For each method, track the linkage name it will eventually have. 
+ for (auto method : classType.getOps()) { + nameStack.push_back(method.name().str()); + auto linkageName = llvm::join(nameStack, "."); + nameStack.pop_back(); + if (!methodLinkageNames.insert({method.function(), linkageName}).second) + method.emitError() << "unbound function shared by multiple methods"; + } + return success(); +} + +void ObjectGraphGlobalizer::createInitializerFunc() { + auto loc = module.getLoc(); + auto func = globalBuilder.create( + loc, GlobalSlotOp::getGlobalSlotInitializerFuncName(), + globalBuilder.getFunctionType({}, {})); + OpBuilder builder(func.getContext()); + Block *body = builder.createBlock(&func.getBody()); + + SmallVector opsToMove; + for (Operation &op : llvm::make_early_inc_range(*module.getBody())) { + if (isa( + &op)) + continue; + op.moveBefore(body, body->end()); + for (Value result : llvm::make_early_inc_range(op.getResults())) { + auto it = slotInitialValuesInverseMap.find(result); + if (it == slotInitialValuesInverseMap.end()) + continue; + for (AttrOfClass attrOfClass : it->second) { + GlobalSlotOp globalSlot = globalSlotForAttr[attrOfClass]; + OpBuilder::atBlockEnd(body).create( + globalSlot.getLoc(), globalSlot.sym_name(), result); + } + } + } + + builder.create(loc); +} + +LogicalResult ObjectGraphGlobalizer::rewriteMethods() { + DenseMap linkageNames; + for (auto classType : module.getOps()) { + for (auto method : classType.getOps()) { + auto it = methodLinkageNames.find(method.function()); + if (it == methodLinkageNames.end()) + continue; + linkageNames[{classType, method.name()}] = it->second; + } + } + // We only handle a small subset of ops that conform with the set of + // assumptions that allow us to globalize the object graph. Anything that + // tries to treat modules as bona-fide objects and not just namespaces + // of methods with a single instance of the corresponding type just gets + // arbitrarily tricky to rewrite. E.g. what if the user creates a list + // of modules, or there is an scf.if selecting between modules, etc. 
+ auto rewriteOpWithNnModuleTypeOperand = [&](Operation *op) { + if (auto primSetAttr = dyn_cast(op)) { + auto classType = symbolTable.lookup( + primSetAttr.receiver().getType().cast().getClassName()); + auto globalSlot = globalSlotForAttr[{classType, primSetAttr.name()}]; + OpBuilder(primSetAttr) + .create(primSetAttr.getLoc(), globalSlot.sym_name(), + primSetAttr.value()); + primSetAttr.erase(); + } + if (auto primGetAttr = dyn_cast(op)) { + auto classType = symbolTable.lookup( + primGetAttr.receiver().getType().cast().getClassName()); + auto globalSlot = globalSlotForAttr[{classType, primGetAttr.name()}]; + auto globalSlotGet = OpBuilder(primGetAttr) + .create(primGetAttr.getLoc(), + primGetAttr.getType(), + globalSlot.sym_name()); + primGetAttr.replaceAllUsesWith(globalSlotGet.getOperation()); + primGetAttr.erase(); + } + if (auto primCallMethod = dyn_cast(op)) { + auto classType = symbolTable.lookup(primCallMethod.receiver() + .getType() + .cast() + .getClassName()); + StringRef linkageName = linkageNames[{classType, primCallMethod.name()}]; + auto call = OpBuilder(primCallMethod) + .create(primCallMethod.getLoc(), linkageName, + primCallMethod.getType(), + primCallMethod.operands()); + primCallMethod.replaceAllUsesWith(call); + primCallMethod.erase(); + } + }; + for (auto classType : module.getOps()) { + for (auto method : classType.getOps()) { + auto it = methodLinkageNames.find(method.function()); + if (it == methodLinkageNames.end()) + continue; + FuncOp func = symbolTable.lookup(method.function()); + func.setVisibility(SymbolTable::Visibility::Public); + func.setName(it->second); + func.walk(rewriteOpWithNnModuleTypeOperand); + SmallVector argsToErase; + for (auto arg : llvm::enumerate(func.getArguments())) { + if (!arg.value().getType().isa()) + continue; + if (!arg.value().use_empty()) { + // TODO: Improve this based on real user use cases. + // This is a diagnostic that users will hit if they do not conform to + // the supported subset of TorchScript. + auto diag = func.emitError().append( + "func argument at index ", arg.index(), + " has uses that were not able to be converted"); + for (Operation *user : arg.value().getUsers()) + diag.attachNote(user->getLoc()).append("see user here"); + return failure(); + } + argsToErase.push_back(arg.index()); + } + func.eraseArguments(argsToErase); + } + } + return success(); +} + +void ObjectGraphGlobalizer::removeObjectGraph() { + for (Operation &op : llvm::make_early_inc_range(*module.getBody())) { + if (isa(op)) + op.erase(); + } +} + +namespace { +class GlobalizeObjectGraphPass + : public GlobalizeObjectGraphBase { + void runOnOperation() override { + if (failed(ObjectGraphGlobalizer(getOperation()).globalizeObjectGraph())) + return signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr> +mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/Torch/Transforms/PassDetail.h b/lib/Dialect/Torch/Transforms/PassDetail.h new file mode 100644 index 000000000..af5f9744f --- /dev/null +++ b/lib/Dialect/Torch/Transforms/PassDetail.h @@ -0,0 +1,25 @@ +//===- PassDetail.h - Pass details ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
+#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace NPCOMP {
+namespace Torch {
+
+#define GEN_PASS_CLASSES
+#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc"
+
+} // namespace Torch
+} // namespace NPCOMP
+} // end namespace mlir
+
+#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp
new file mode 100644
index 000000000..fa6594e53
--- /dev/null
+++ b/lib/Dialect/Torch/Transforms/Passes.cpp
@@ -0,0 +1,20 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "npcomp/Dialect/Torch/Transforms/Passes.h"
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+#define GEN_PASS_REGISTRATION
+#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc"
+} // end namespace
+
+void mlir::NPCOMP::registerTorchPasses() { ::registerPasses(); }
diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp
index dd9a67021..231d740ac 100644
--- a/lib/InitAll.cpp
+++ b/lib/InitAll.cpp
@@ -21,6 +21,7 @@
 #include "npcomp/Dialect/TCP/IR/TCPDialect.h"
 #include "npcomp/Dialect/TCP/Transforms/Passes.h"
 #include "npcomp/Dialect/Torch/IR/TorchDialect.h"
+#include "npcomp/Dialect/Torch/Transforms/Passes.h"
 #include "npcomp/Typing/Transforms/Passes.h"
 #include "npcomp/Conversion/Passes.h"
 
@@ -47,5 +48,6 @@ void mlir::NPCOMP::registerAllPasses() {
   mlir::NPCOMP::registerNumpyPasses();
   mlir::NPCOMP::registerTCFPasses();
   mlir::NPCOMP::registerTCPPasses();
+  mlir::NPCOMP::registerTorchPasses();
   mlir::NPCOMP::registerTypingPasses();
 }
diff --git a/test/Dialect/Torch/globalize-object-graph-error.mlir b/test/Dialect/Torch/globalize-object-graph-error.mlir
new file mode 100644
index 000000000..c659c0768
--- /dev/null
+++ b/test/Dialect/Torch/globalize-object-graph-error.mlir
@@ -0,0 +1,63 @@
+// RUN: npcomp-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
+
+torch.class_type @c1 {}
+torch.class_type @c2 {}
+
+// expected-note @+1 {{see other root module here}}
+torch.nn_module {} : !torch.nn.Module<"c1">
+// expected-error @+1 {{found more than one root module (module that is not a child of any other module)}}
+torch.nn_module {} : !torch.nn.Module<"c2">
+
+// -----
+
+// expected-error @+1 {{class type has more than one instance: the current TorchScript supported subset only allows single instances}}
+torch.class_type @child {}
+torch.class_type @parent {
+  torch.attr "m1" : !torch.nn.Module<"child">
+  torch.attr "m2" : !torch.nn.Module<"child">
+}
+
+// expected-note @+1 {{see instance here}}
+%0 = torch.nn_module {} : !torch.nn.Module<"child">
+// expected-note @+1 {{see instance here}}
+%1 = torch.nn_module {} : !torch.nn.Module<"child">
+
+%root = torch.nn_module {
+  torch.slot "m1", %0 : !torch.nn.Module<"child">
+  torch.slot "m2", %1 : !torch.nn.Module<"child">
+} : !torch.nn.Module<"parent">
+
+// -----
+
+torch.class_type @c {
+  torch.method "f", @f
+}
+// expected-error @+1 {{func argument at index 1 has uses that were not able to be converted}}
+func private @f(%arg0: !torch.nn.Module<"c">, %arg1: !torch.nn.Module<"c">) {
+  // expected-note @+1 {{see user here}}
+  %0 = basicpy.build_list %arg1 : (!torch.nn.Module<"c">) -> !basicpy.ListType
+  return
+}
+
+torch.nn_module {} : !torch.nn.Module<"c">
+
+// -----
+
+// expected-error @+1 {{reachable by multiple paths from root object: '.m' and '.m2'}}
+torch.class_type @child {
+  torch.attr "float" : f64
+}
+torch.class_type @parent {
+  torch.attr "m" : !torch.nn.Module<"child">
+  torch.attr "m2" : !torch.nn.Module<"child">
+
+}
+
+%c42 = std.constant 42.0 : f64
+%child = torch.nn_module {
+  torch.slot "float", %c42 : f64
+} : !torch.nn.Module<"child">
+%parent = torch.nn_module {
+  torch.slot "m", %child : !torch.nn.Module<"child">
+  torch.slot "m2", %child : !torch.nn.Module<"child">
+} : !torch.nn.Module<"parent">
diff --git a/test/Dialect/Torch/globalize-object-graph-methods.mlir b/test/Dialect/Torch/globalize-object-graph-methods.mlir
new file mode 100644
index 000000000..755635fad
--- /dev/null
+++ b/test/Dialect/Torch/globalize-object-graph-methods.mlir
@@ -0,0 +1,39 @@
+// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
+
+torch.class_type @c {
+  torch.attr "float" : f64
+  torch.method "test_get", @test_get
+  torch.method "test_set", @test_set
+  torch.method "test_call", @test_call
+}
+
+// CHECK-LABEL: func @test_get() -> f64 {
+// CHECK: %[[V:.*]] = torch.global_slot.get @float : f64
+// CHECK: return %[[V]] : f64
+func private @test_get(%arg0: !torch.nn.Module<"c">) -> f64 {
+  %0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> f64
+  return %0 : f64
+}
+
+// CHECK-LABEL: func @test_set(
+// CHECK-SAME: %[[A:.*]]: f64) {
+// CHECK: torch.global_slot.set @float = %[[A]] : f64
+// CHECK: return
+func private @test_set(%arg0: !torch.nn.Module<"c">, %arg1: f64) {
+  torch.prim.SetAttr %arg0["float"] = %arg1 : !torch.nn.Module<"c">, f64
+  return
+}
+
+// CHECK-LABEL: func @test_call(
+// CHECK-SAME: %[[A:.*]]: f64) -> f64 {
+// CHECK: %[[V:.*]] = call @test_call(%[[A]]) : (f64) -> f64
+// CHECK: return %[[V]] : f64
+func private @test_call(%arg0: !torch.nn.Module<"c">, %arg1: f64) -> f64 {
+  %0 = torch.prim.CallMethod %arg0["test_call"] (%arg1) : !torch.nn.Module<"c">, (f64) -> f64
+  return %0 : f64
+}
+
+%c42 = std.constant 42.0 : f64
+torch.nn_module {
+  torch.slot "float", %c42 : f64
+} : !torch.nn.Module<"c">
diff --git a/test/Dialect/Torch/globalize-object-graph-submodules.mlir b/test/Dialect/Torch/globalize-object-graph-submodules.mlir
new file mode 100644
index 000000000..544640b4d
--- /dev/null
+++ b/test/Dialect/Torch/globalize-object-graph-submodules.mlir
@@ -0,0 +1,25 @@
+// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
+
+// Check that linkage names consist of the dotted path from the root.
+
+// CHECK-LABEL: torch.global_slot @m.float : f64
+
+// CHECK-LABEL: func @__torch_global_slot_initializer() {
+// CHECK: %[[C42:.*]] = constant 4.200000e+01 : f64
+// CHECK: torch.global_slot.set @m.float = %[[C42]] : f64
+// CHECK: return
+
+torch.class_type @child {
+  torch.attr "float" : f64
+}
+torch.class_type @parent {
+  torch.attr "m" : !torch.nn.Module<"child">
+}
+
+%c42 = std.constant 42.0 : f64
+%child = torch.nn_module {
+  torch.slot "float", %c42 : f64
+} : !torch.nn.Module<"child">
+%parent = torch.nn_module {
+  torch.slot "m", %child : !torch.nn.Module<"child">
+} : !torch.nn.Module<"parent">
diff --git a/test/Dialect/Torch/globalize-object-graph.mlir b/test/Dialect/Torch/globalize-object-graph.mlir
new file mode 100644
index 000000000..0e51d0141
--- /dev/null
+++ b/test/Dialect/Torch/globalize-object-graph.mlir
@@ -0,0 +1,63 @@
+// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
+
+// Basic case.
+
+// CHECK-LABEL: torch.global_slot @b : !basicpy.BoolType
+// CHECK: torch.global_slot @i : i64
+// CHECK: torch.global_slot @f : f64
+// CHECK: torch.global_slot @a : !numpy.ndarray<*:!numpy.any_dtype>
+
+// CHECK-LABEL: func @__torch_global_slot_initializer() {
+// CHECK: %[[CB:.*]] = basicpy.bool_constant true
+// CHECK: torch.global_slot.set @b = %[[CB]] : !basicpy.BoolType
+// CHECK: %[[CI:.*]] = basicpy.numeric_constant 3 : i64
+// CHECK: torch.global_slot.set @i = %[[CI]] : i64
+// CHECK: %[[CF:.*]] = basicpy.numeric_constant 4.250000e+01 : f64
+// CHECK: torch.global_slot.set @f = %[[CF]] : f64
+// CHECK: %[[C:.*]] = constant dense<1.000000e+00> : tensor<1xf32>
+// CHECK: %[[CA:.*]] = numpy.create_array_from_tensor %[[C]] : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
+// CHECK: torch.global_slot.set @a = %[[CA]] : !numpy.ndarray<*:!numpy.any_dtype>
+// CHECK: return
+
+torch.class_type @c {
+  torch.attr "b" : !basicpy.BoolType
+  torch.attr "i" : i64
+  torch.attr "f" : f64
+  torch.attr "a" : !numpy.ndarray<*:!numpy.any_dtype>
+}
+
+%bool_true = basicpy.bool_constant true
+%i = basicpy.numeric_constant 3 : i64
+%f = basicpy.numeric_constant 4.250000e+01 : f64
+%cst = constant dense<1.000000e+00> : tensor<1xf32>
+%a = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
+torch.nn_module {
+  torch.slot "b", %bool_true : !basicpy.BoolType
+  torch.slot "i", %i : i64
+  torch.slot "f", %f : f64
+  torch.slot "a", %a : !numpy.ndarray<*:!numpy.any_dtype>
+} : !torch.nn.Module<"c">
+
+// -----
+
+// Same SSA value used as initializer for multiple slots.
+
+// CHECK-LABEL: torch.global_slot @b1 : !basicpy.BoolType
+// CHECK-LABEL: torch.global_slot @b2 : !basicpy.BoolType
+// CHECK-LABEL: func @__torch_global_slot_initializer() {
+// CHECK: %[[TRUE:.*]] = basicpy.bool_constant true
+// CHECK: torch.global_slot.set @b1 = %[[TRUE]] : !basicpy.BoolType
+// CHECK: torch.global_slot.set @b2 = %[[TRUE]] : !basicpy.BoolType
+// CHECK: return
+// CHECK: }
+
+torch.class_type @c {
+  torch.attr "b1" : !basicpy.BoolType
+  torch.attr "b2" : !basicpy.BoolType
+}
+
+%bool_true = basicpy.bool_constant true
+torch.nn_module {
+  torch.slot "b1", %bool_true : !basicpy.BoolType
+  torch.slot "b2", %bool_true : !basicpy.BoolType
+} : !torch.nn.Module<"c">
diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir
index cd518d8e8..dcaf8e36e 100644
--- a/test/Dialect/Torch/invalid.mlir
+++ b/test/Dialect/Torch/invalid.mlir
@@ -2,14 +2,98 @@
 
 // -----
 
-torch.nn_module {
-  // expected-error @+1 {{'func' op is not allowed inside `torch.nn_module`}}
+torch.class_type @c {}
+%0 = torch.nn_module {
+  // expected-error @+1 {{'func' op is not allowed inside 'torch.nn_module'}}
+  func @f()
+} : !torch.nn.Module<"c">
+
+// -----
+
+torch.class_type @c {}
+%c0 = constant 0 : i64
+// expected-error @+1 {{number of 'torch.slot's in a 'torch.nn_module' must match number of 'torch.attr's in the corresponding 'torch.class_type'}}
+%0 = torch.nn_module {
+  torch.slot "f", %c0 : i64
+} : !torch.nn.Module<"c">
+
+// -----
+
+torch.class_type @c {
+  // expected-note @+1 {{see torch.attr at corresponding index 0 here}}
+  torch.attr "g" : i64
+}
+%c0 = constant 0 : i64
+%0 = torch.nn_module {
+  // expected-error @+1 {{'torch.slot' op is expected to match type and name of 'torch.attr "g" : i64'}}
+  torch.slot "f", %c0 : i64
+} : !torch.nn.Module<"c">
+
+// -----
+
+torch.class_type @c {
+  // expected-error @+1 {{'func' op is not allowed inside `torch.class_type`}}
   func @f()
 }
 
 // -----
 
-torch.nn_module {
-  // expected-error @+1 {{'invalidSym' does not reference a valid function}}
+// expected-error @+1 {{has duplicate attr/method with name 'a'}}
+torch.class_type @c {
+  // expected-note @+1 {{see first conflicting attr/method here}}
+  torch.attr "a" : i64
+  // expected-note @+1 {{see second conflicting attr/method here}}
+  torch.attr "a" : i64
+}
+
+// -----
+
+torch.class_type @c {
+  // expected-error @+1 {{'@invalidSym' does not reference a valid function}}
   torch.method "f", @invalidSym
 }
+
+// -----
+
+torch.class_type @c {
+  // expected-error @+1 {{'@f' must reference a private function}}
+  torch.method "f", @f
+}
+
+func @f(%arg0: !torch.nn.Module<"c">) {
+  return
+}
+
+// -----
+
+torch.class_type @c {
+  // expected-error @+1 {{'@f' must reference a function that is defined (not merely declared)}}
+  torch.method "f", @f
+}
+
+func private @f(%arg0: !torch.nn.Module<"c">)
+
+// -----
+
+func private @f() {
+  return
+}
+torch.class_type @c {
+  // expected-error @+1 {{the referenced function 'f' must have a first argument of type '!torch.nn.Module<"c">'}}
+  torch.method "f", @f
+}
+
+// -----
+
+func private @f(!torch.nn.Module<"other_c">) {
+  return
+}
+torch.class_type @c {
+  // expected-error @+1 {{the referenced function 'f' must have a first argument of type '!torch.nn.Module<"c">'}}
+  torch.method "f", @f
+}
+
+// -----
+
+// expected-error @+1 {{'a' does not reference a valid class type}}
+%m = torch.nn_module {} : !torch.nn.Module<"a">
diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir
index ae003ac8e..16d610c7a 100644
--- a/test/Dialect/Torch/ops.mlir
+++ b/test/Dialect/Torch/ops.mlir
@@ -14,16 +14,28 @@ func @kernel_call(%arg0 : si32, %arg1 : tensor<3x4xf32>) -> tensor<*xf32> {
 %num = basicpy.numeric_constant 4.250000e+01 : f64
 %cst = constant dense<1.000000e+00> : tensor<1xf32>
 %array = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
 
-func @f(%arg0: !torch.nn.Module) {
+%none = basicpy.singleton : !basicpy.NoneType
+func private @f(%arg0: !torch.nn.Module<"test">) {
   return
 }
-%submodule = torch.nn_module {}
-torch.nn_module {
-  torch.attr "b", %bool_true : !basicpy.BoolType
-  torch.attr "i", %num3_i64 : i64
-  torch.attr "f", %num : f64
-  torch.attr "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
-  torch.attr "submodule", %submodule : !torch.nn.Module
+torch.class_type @empty {}
+%submodule = torch.nn_module {} : !torch.nn.Module<"empty">
+
+torch.class_type @test {
+  torch.attr "b" : !basicpy.BoolType
+  torch.attr "i" : i64
+  torch.attr "f" : f64
+  torch.attr "t" : !numpy.ndarray<*:!numpy.any_dtype>
+  torch.attr "submodule" : !torch.nn.Module<"empty">
+  torch.attr "ob" : !torch.optional
   torch.method "method", @f
 }
+torch.nn_module {
+  torch.slot "b", %bool_true : !basicpy.BoolType
+  torch.slot "i", %num3_i64 : i64
+  torch.slot "f", %num : f64
+  torch.slot "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
+  torch.slot "submodule", %submodule : !torch.nn.Module<"empty">
+  torch.slot "ob", %none : !basicpy.NoneType
+} : !torch.nn.Module<"test">