Implement GlobalizeObjectGraph transformation.

This required restructuring how we model TorchScript on import. The
main difference is that we now split out a `torch.class_type` op that
holds the methods and the declared types of each slot. This is more
consistent with TorchScript (our previous representation was
"denormalized"). A minimal sketch of the new form appears after the
reading list below.

Recommended reading order:
1. Check out the description of `torch.class_type` in `TorchOps.td` and
   look at `test/Dialect/Torch/ops.mlir` and
   `frontends/pytorch/test/module_import/` to familiarize yourself with
   the new representation.
   - Just look at the new IR; the diff between the old and new names is
     confusing to read.
2. Check out `test/Dialect/Torch/globalize-object-graph*.mlir`
   and read along with the pass description in
   `include/npcomp/Dialect/Torch/Transforms/Passes.td`
3. Read the code in `GlobalizeObjectGraph.cpp` and miscellaneous changes
   in `ivalue_importer.cpp`, `TorchOps.cpp`, etc.
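
For orientation, here is a minimal sketch of the new representation. The
class name, slot, and function symbol below are illustrative only; the
authoritative examples are in `TorchOps.td` and the tests listed above.

```mlir
// The class type declares each slot's name and type, plus the methods.
torch.class_type @__torch__.TestModule {
  torch.attr "b" : !basicpy.BoolType
  torch.method "forward", @some_private_fn
}
// The nn_module instantiates that class type; its torch.slot ops must
// match the torch.attr declarations in name, order, and type.
%bool_true = basicpy.bool_constant true
%root = torch.nn_module {
  torch.slot "b", %bool_true : !basicpy.BoolType
} : !torch.nn.Module<"__torch__.TestModule">
```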
pull/162/head
Sean Silva 2021-02-17 11:28:51 -08:00
parent 99d1db18d2
commit 158c5c484d
34 changed files with 1275 additions and 85 deletions


@ -101,7 +101,8 @@ public:
private:
MlirValue rawImportIValue(c10::IValue value);
MlirValue importModule(torch::jit::Module jitModule);
-void importMethod(torch::jit::Function *function, MlirBlock nnModuleBody);
+void importMethod(torch::jit::Function *function, MlirBlock classTypeBody);
+void importClassType(c10::ClassType *classType);
MlirBlock importBlock;
MlirContext context;
@ -111,6 +112,12 @@ private:
std::unordered_map<c10::IValue, MlirValue, IValueHasher, IValueEq> valueMap;
// Used to detect potentially aliasing tensors.
std::unordered_set<c10::StorageImpl *> seenStorageImpls;
+// The set of ClassType's that have already been imported.
+//
+// ClassType's are referenced via their `classType->name()->qualifiedName()`
+// string (as an MLIR symbol name) so we don't need to keep a map associating
+// them with the MlirOperation that they import into.
+std::unordered_set<c10::ClassType *> classTypes;
// The stack of attribute names we have traversed to reach the current IValue.
// Used for diagnostics.
std::vector<std::string> attributeNameStack;
@ -128,16 +135,25 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
// TODO: Can we do better?
MlirLocation loc = mlirLocationUnknownGet(context);
-MlirOperation nnModule =
-createMlirOperation("torch.nn_module", loc,
-npcompNnModuleTypeGet(context), mlirRegionCreate());
+c10::optional<c10::QualifiedName> maybeName = currentModule.type()->name();
+if (!maybeName) {
+throw std::invalid_argument("cannot import unnamed module");
+}
+std::string moduleTypeName = maybeName->qualifiedName();
+// Ensure the class type has been imported.
+importClassType(currentModule.type().get());
+MlirOperation nnModule = createMlirOperation(
+"torch.nn_module", loc,
+npcompNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
+mlirRegionCreate());
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
if (!rootModuleName.has_value()) {
-c10::optional<c10::QualifiedName> maybeName = currentModule.type()->name();
-rootModuleName = maybeName ? maybeName->qualifiedName() : "unnamed module";
+rootModuleName = moduleTypeName;
}
const std::vector<c10::IValue> &slots = currentModule._ivalue()->slots();
@ -151,7 +167,7 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
MlirValue slotValue = importIValue(slots[i]);
// TODO: Is it necessary to track whether an attribute is a "parameter"?
createMlirOperationAtEnd(
-nnModuleBody, "torch.attr", loc, slotValue,
+nnModuleBody, "torch.slot", loc, slotValue,
toMlirNamedAttribute(
"name", mlirStringAttrGet(
context, toMlirStringRef(classAttribute.getName()))));
@ -162,10 +178,6 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
rootModuleName = c10::nullopt;
}
-for (torch::jit::Function *function : currentModule.type()->methods()) {
-importMethod(function, nnModuleBody);
-}
createMlirOperationAtEnd(nnModuleBody, "torch.nn_module_terminator", loc);
mlirBlockInsertOwnedOperationBefore(
importBlock, mlirBlockGetTerminator(importBlock), nnModule);
@ -262,7 +274,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
}
void IValueImporter::importMethod(torch::jit::Function *function,
-MlirBlock nnModuleBody) {
+MlirBlock classTypeBody) {
// We make an effort for the func op's symbol name to be useful for debugging,
// but still clearly non-load-bearing.
std::string symName =
@ -275,13 +287,50 @@ void IValueImporter::importMethod(torch::jit::Function *function,
mlirBlockInsertOwnedOperationBefore(
importBlock, mlirBlockGetTerminator(importBlock), func);
createMlirOperationAtEnd(
-nnModuleBody, "torch.method", mlirLocationUnknownGet(context),
+classTypeBody, "torch.method", mlirLocationUnknownGet(context),
toMlirNamedAttribute(
"name",
mlirStringAttrGet(context, toMlirStringRef(function->name()))),
toMlirNamedAttribute("function", mlirFlatSymbolRefAttrGet(
context, toMlirStringRef(symName))));
}
+void IValueImporter::importClassType(c10::ClassType *classType) {
+if (!classTypes.insert(classType).second) {
+return;
+}
+// TODO: Can we do better?
+MlirLocation loc = mlirLocationUnknownGet(context);
+MlirOperation op = createMlirOperationAtEnd(
+importBlock, "torch.class_type", loc, mlirRegionCreate(),
+toMlirNamedAttribute(
+"sym_name",
+mlirStringAttrGet(
+context, toMlirStringRef(classType->name()->qualifiedName()))));
+MlirRegion region = mlirOperationGetRegion(op, 0);
+mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr));
+MlirBlock classTypeBody = mlirRegionGetFirstBlock(region);
+for (const c10::ClassAttribute &classAttribute : classType->getAttributes()) {
+createMlirOperationAtEnd(
+classTypeBody, "torch.attr", loc,
+toMlirNamedAttribute(
+"name", mlirStringAttrGet(
+context, toMlirStringRef(classAttribute.getName()))),
+toMlirNamedAttribute("type",
+mlirTypeAttrGet(typeMapper.mapFromTorchType(
+loc, classAttribute.getType()))));
+}
+for (torch::jit::Function *function : classType->methods()) {
+importMethod(function, classTypeBody);
+}
+createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc);
+}
void torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
MlirContext context) {
// When debugging module importing, it can be useful to dump as so:


@ -106,11 +106,19 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType);
}
case TypeKind::ClassType: {
-return npcompNnModuleTypeGet(context);
+auto maybeName = torchType->cast<c10::ClassType>()->name();
+return npcompNnModuleTypeGet(
+context, toMlirStringRef(maybeName ? maybeName->qualifiedName()
+: "unnamed class"));
}
case TypeKind::FloatType: {
return mlirF64TypeGet(context);
}
+case TypeKind::OptionalType: {
+return npcompOptionalTypeGet(
+mapFromTorchType(
+loc, torchType->cast<c10::OptionalType>()->getElementType()));
+}
case TypeKind::IntType: {
return mlirIntegerTypeGet(context, 64);
}
@ -120,6 +128,10 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
case TypeKind::BoolType: {
return npcompBoolTypeGet(context);
}
+case TypeKind::ListType: {
+// TODO: Don't lose the element type information.
+return npcompListTypeGet(context);
+}
default: {
std::stringstream message;
message << "unable to map Torch type " << *torchType << " to MLIR type";


@ -15,13 +15,16 @@ class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l = [1, 2]
+# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
+# TODO: Don't lose element type.
+# CHECK: torch.attr "l" : !basicpy.ListType
+# CHECK: }
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
# CHECK: %[[N2:.*]] = basicpy.numeric_constant 2 : i64
# CHECK: %[[LIST:.*]] = basicpy.build_list %[[N1]], %[[N2]] : (i64, i64) -> !basicpy.ListType
# CHECK: torch.nn_module {
-# CHECK: torch.attr "l", %[[LIST]] : !basicpy.ListType
+# CHECK: torch.slot "l", %[[LIST]] : !basicpy.ListType
-# CHECK: }
+# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
test_module = TestModule()


@ -32,8 +32,8 @@ class TestModule(torch.nn.Module):
# the case that the name is `__main__` Torch replaces it with `__torch__` to
# avoid collisions.
-# CHECK: func private @__npcomp_priv_fn.__torch__.Submodule.forward
-# CHECK: func private @__npcomp_priv_fn.__torch__.TestModule.forward
+# CHECK-DAG: func private @__npcomp_priv_fn.__torch__.TestModule.forward
+# CHECK-DAG: func private @__npcomp_priv_fn.__torch__.Submodule.forward
test_module = TestModule()


@ -19,17 +19,22 @@ class TestModule(torch.nn.Module):
# The symbol name of the function is NOT load-bearing and cannot be relied upon.
+# CHECK-LABEL: torch.class_type
+# CHECK-SAME: @[[CLASSTYPE:.*]] {
+# CHECK: torch.method "forward", @[[SYMNAME:.*]]
+# CHECK: }
# CHECK-LABEL: func private
-# CHECK-SAME: @[[SYMNAME:.*]](
+# CHECK-SAME: @[[SYMNAME]](
-# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module,
+# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"[[CLASSTYPE]]">,
# CHECK-SAME: %[[X:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
# CHECK-SAME: %[[Y:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
# CHECK: %[[RET:.*]] = torch.kernel_call "aten::mul" %[[X]], %[[Y]]
# CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
# CHECK: %[[ROOT:.*]] = torch.nn_module {
-# CHECK: torch.method "forward", @[[SYMNAME]]
-# CHECK: }
+# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
test_module = TestModule()


@ -24,8 +24,8 @@ class TestModule(torch.nn.Module):
# CHECK: %[[L2:.*]] = basicpy.build_list
# CHECK: %[[L1:.*]] = basicpy.build_list
# CHECK: torch.nn_module {
-# CHECK: torch.attr "l2", %[[L2]]
+# CHECK: torch.slot "l2", %[[L2]]
-# CHECK: torch.attr "l1", %[[L1]]
+# CHECK: torch.slot "l1", %[[L1]]
self.l2 = self.l1 = [1]
# This can be uncommented when the graph importer supports it.


@ -16,8 +16,8 @@ class TestModule(torch.nn.Module):
super().__init__()
# CHECK: %[[A:.*]] = numpy.create_array_from_tensor
# CHECK: torch.nn_module {
-# CHECK: torch.attr "t1", %[[A]]
+# CHECK: torch.slot "t1", %[[A]]
-# CHECK: torch.attr "t2", %[[A]]
+# CHECK: torch.slot "t2", %[[A]]
self.t1 = self.t2 = torch.tensor([10., 20.])


@ -18,7 +18,7 @@ class TestModule(torch.nn.Module):
self.t2 = torch.ones(1)
# CHECK-LABEL: func{{.*}}TestModule.forward{{.*}}(
-# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module) -> !basicpy.NoneType {
+# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">) -> !basicpy.NoneType {
def forward(self):
# CHECK: %[[T2:.*]] = torch.prim.GetAttr %[[SELF]]["t2"]
# CHECK: torch.prim.SetAttr %[[SELF]]["t1"] = %[[T2]]
@ -26,7 +26,7 @@ class TestModule(torch.nn.Module):
# CHECK: torch.prim.CallMethod %[[SELF]]["callee"] (%{{.*}}, %{{.*}})
self.callee(self.t1, self.t2)
# CHECK-LABEL: func{{.*}}TestModule.callee{{.*}}(
-# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module,
+# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">,
# CHECK-SAME: %[[X:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
# CHECK-SAME: %[[Y:.*]]: !numpy.ndarray<*:!numpy.any_dtype>
def callee(self, x, y):


@ -17,14 +17,20 @@ class TestModule(torch.nn.Module):
self.i = 3
self.f = 42.5
+# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
+# CHECK: torch.attr "training" : !basicpy.BoolType
+# CHECK: torch.attr "i" : i64
+# CHECK: torch.attr "f" : f64
+# CHECK: }
# CHECK: %[[TRUE:.*]] = basicpy.bool_constant true
# CHECK: %[[N3:.*]] = basicpy.numeric_constant 3 : i64
# CHECK: %[[N42:.*]] = basicpy.numeric_constant 4.250000e+01 : f64
# CHECK: %[[MODULE:.*]] = torch.nn_module {
# Note: for some reason, Torch always adds a "training" property to all modules.
-# CHECK: torch.attr "training", %[[TRUE]] : !basicpy.BoolType
+# CHECK: torch.slot "training", %[[TRUE]] : !basicpy.BoolType
-# CHECK: torch.attr "i", %[[N3]] : i64
+# CHECK: torch.slot "i", %[[N3]] : i64
-# CHECK: torch.attr "f", %[[N42]] : f64
+# CHECK: torch.slot "f", %[[N42]] : f64
+# CHECK: } : !torch.nn.Module<"[[CLASSTYPE:.*]]">
test_module = TestModule()


@ -15,6 +15,8 @@ class Submodule(torch.nn.Module):
def __init__(self, n):
super().__init__()
self.n = n
+def forward(self):
+return self.n
class TestModule(torch.nn.Module):
def __init__(self):
@ -27,9 +29,9 @@ class TestModule(torch.nn.Module):
# Modules with the same class can be selected between.
# CHECK: %[[MOD:.*]] = scf.if
s = self.s1 if b else self.s2
-# CHECK: %[[N:.*]] = torch.prim.GetAttr %5["n"]
+# CHECK: %[[N:.*]] = torch.prim.CallMethod %[[MOD]]["forward"] ()
# CHECK: return %[[N]]
-return s.n
+return s.forward()
test_module = TestModule()


@ -26,20 +26,20 @@ class TestModule(torch.nn.Module):
# CHECK: %[[N0:.*]] = basicpy.numeric_constant 0 : i64
# CHECK: %[[S0:.*]] = torch.nn_module {
-# CHECK: torch.attr "training", %[[T]] : !basicpy.BoolType
+# CHECK: torch.slot "training", %[[T]] : !basicpy.BoolType
-# CHECK: torch.attr "n", %[[N0]] : i64
+# CHECK: torch.slot "n", %[[N0]] : i64
# CHECK: }
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
# CHECK: %[[S1:.*]] = torch.nn_module {
-# CHECK: torch.attr "training", %[[T]] : !basicpy.BoolType
+# CHECK: torch.slot "training", %[[T]] : !basicpy.BoolType
-# CHECK: torch.attr "n", %[[N1]] : i64
+# CHECK: torch.slot "n", %[[N1]] : i64
# CHECK: }
# CHECK: %[[ROOT:.*]] = torch.nn_module {
-# CHECK: torch.attr "training", %[[T]] : !basicpy.BoolType
+# CHECK: torch.slot "training", %[[T]] : !basicpy.BoolType
-# CHECK: torch.attr "s0", %[[S0]] : !torch.nn.Module
+# CHECK: torch.slot "s0", %[[S0]] : !torch.nn.Module
-# CHECK: torch.attr "s1", %[[S1]] : !torch.nn.Module
+# CHECK: torch.slot "s1", %[[S1]] : !torch.nn.Module
# CHECK: }


@ -23,8 +23,8 @@ class TestModule(torch.nn.Module):
# CHECK: %[[CT:.*]] = constant dense<1.000000e+00> : tensor<1xf32>
# CHECK: %[[T:.*]] = numpy.create_array_from_tensor %[[CT]] : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
# CHECK: %[[ROOT:.*]] = torch.nn_module {
-# CHECK: torch.attr "p", %[[P]] : !numpy.ndarray<*:!numpy.any_dtype>
+# CHECK: torch.slot "p", %[[P]] : !numpy.ndarray<*:!numpy.any_dtype>
-# CHECK: torch.attr "t", %[[T]] : !numpy.ndarray<*:!numpy.any_dtype>
+# CHECK: torch.slot "t", %[[T]] : !numpy.ndarray<*:!numpy.any_dtype>
# CHECK: }


@ -124,8 +124,18 @@ MlirType npcompTupleTypeGet(MlirContext context);
/** Checks whether the given type is a torch.nn.Module type */
int npcompTypeIsANnModule(MlirType t);
-/** Gets the singleton torch.nn.Module type. */
+/** Gets the !torch.nn.Module type of the specified class. */
-MlirType npcompNnModuleTypeGet(MlirContext context);
+MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className);
+/*============================================================================*/
+/* torch.optional type. */
+/*============================================================================*/
+/** Checks whether the given type is a !torch.optional<T> type */
+int npcompTypeIsAOptional(MlirType t);
+/** Gets the !torch.optional<T> type with subtype T. */
+MlirType npcompOptionalTypeGet(MlirType containedType);
#ifdef __cplusplus
}


@ -1 +1,2 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)


@ -49,10 +49,11 @@ def Torch_KernelCallOp : Torch_Op<"kernel_call", [
}
//===----------------------------------------------------------------------===//
-// TorchScript modeling ops.
+// TorchScript `torch.nn.Module` object instantiation ops.
//===----------------------------------------------------------------------===//
def Torch_NnModuleOp : Torch_Op<"nn_module", [
+DeclareOpInterfaceMethods<SymbolUserOpInterface>,
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::NnModuleTerminatorOp">]> {
let summary = "Constructs a torch.nn.Module";
let description = [{
@ -65,15 +66,19 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
Example:
```mlir
%2 = torch.nn_module {
-torch.attr "b", %bool_true : !basicpy.BoolType
+torch.slot "b", %bool_true : !basicpy.BoolType
-torch.attr "i", %num3_i64 : i64
+torch.slot "i", %num3_i64 : i64
-torch.attr "f", %num : f64
+torch.slot "f", %num : f64
-torch.attr "t", %0 : !numpy.ndarray<*:!numpy.any_dtype>
+torch.slot "t", %0 : !numpy.ndarray<*:!numpy.any_dtype>
-torch.attr "submodule", %1 : !torch.nn.Module
+torch.slot "submodule", %1 : !torch.nn.Module
-torch.method "method", @f
-}
+} : !torch.nn.Module<"my_class_name">
```
+This op is tightly coupled to the `torch.class_type` op named in the
+`!torch.nn.Module<"my_class_name">` type. Each slot must match precisely
+with the corresponding `torch.attr` in the `torch.class_type`.
+See the documentation for `torch.class_type` for information.
}];
let arguments = (ins);
@ -81,7 +86,11 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
let regions = (region SizedRegion<1>:$region);
let verifier = "return ::verify(*this);";
-let assemblyFormat = "$region attr-dict";
+let assemblyFormat = "$region attr-dict `:` type($result)";
+let extraClassDeclaration = [{
+StringRef getClassName() { return getType().getClassName(); }
+}];
}
def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
@ -94,13 +103,13 @@ def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
let assemblyFormat = "attr-dict";
}
-def Torch_AttrOp : Torch_Op<"attr", [
+def Torch_SlotOp : Torch_Op<"slot", [
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
-let summary = "Define an attribute of a torch.nn.Module";
+let summary = "Define the value of a slot of a torch.nn.Module";
let description = [{
-This op declaratively specifies that the parent torch.nn_module has an
+This op specifies that the initial value of the slot `name` of the
-attribute `name` with value `value`, which is allowed to be an arbitrary
+parent torch.nn_module should be `value`, which is allowed to be an
-Torch-compatible SSA value, including other torch.nn.Module's.
+arbitrary Torch-compatible SSA value, including other !torch.nn.Module's.
}];
let arguments = (ins StrAttr:$name, AnyTorchType:$value);
@ -111,13 +120,73 @@
}];
}
//===----------------------------------------------------------------------===//
// Modeling of TorchScript class types
//===----------------------------------------------------------------------===//
def Torch_ClassTypeOp : Torch_Op<"class_type", [
Symbol,
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::ClassTypeTerminatorOp">]> {
let summary = "Constructs a torch.ClassType";
let description = [{
Declares a class type. Class types are the types used to describe
TorchScript `torch.nn.Module`'s. The terminology "class type" is for
consistency with TorchScript (a better name in our context might be
"nn module subtype"). The `sym_name` of this op is the same string
as in the `!torch.nn.Module<"...">` type.
Example:
```mlir
// A simple empty torch.class_type, with corresponding torch.nn_module.
torch.class_type @empty {}
%submodule = torch.nn_module {} : !torch.nn.Module<"empty">
// A class type with many members.
torch.class_type @test {
torch.attr "b" : !basicpy.BoolType
torch.attr "i" : i64
torch.attr "f" : f64
torch.attr "t" : !numpy.ndarray<*:!numpy.any_dtype>
torch.attr "submodule" : !torch.nn.Module<"empty">
torch.method "method", @f
}
torch.nn_module {
// These must match the order and names in the `torch.class_type`.
torch.slot "b", %bool_true : !basicpy.BoolType
torch.slot "i", %num3_i64 : i64
torch.slot "f", %num : f64
torch.slot "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
torch.slot "submodule", %submodule : !torch.nn.Module<"empty">
} : !torch.nn.Module<"test">
```
}];
let arguments = (ins SymbolNameAttr:$sym_name);
let results = (outs);
let regions = (region SizedRegion<1>:$region);
let verifier = "return ::verify(*this);";
let assemblyFormat = "$sym_name $region attr-dict";
}
def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator,
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">]> {
let summary = "Implicit terminator for torch.class_type";
let arguments = (ins);
let results = (outs);
let assemblyFormat = "attr-dict";
}
def Torch_MethodOp : Torch_Op<"method", [
-HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">,
+HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
-let summary = "Define a method of a torch.nn.Module";
+let summary = "Declare a method of a torch.class_type";
let description = [{
-This op declaratively specifies that the parent torch.nn_module has a
+This op declaratively specifies that the parent torch.class_type has a
method `name` which calls `function`. `function` is an unbound function.
That is, it explicitly takes the torch.nn.Module as a parameter (no implicit
"self" object).
@ -131,6 +200,88 @@ def Torch_MethodOp : Torch_Op<"method", [
}];
}
def Torch_AttrOp : Torch_Op<"attr", [
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">
]> {
let summary = "Declare an attribute of a torch.class_type";
let description = [{
This op declaratively specifies that torch.nn.Module's of the parent
torch.class_type must have an attribute `name` of type `type`.
}];
let arguments = (ins StrAttr:$name, TypeAttr:$type);
let results = (outs);
let assemblyFormat = [{
$name `:` $type attr-dict
}];
}
//===----------------------------------------------------------------------===//
// Global slot ops
//===----------------------------------------------------------------------===//
// TODO: Should these be in a separate dialect?
// At this point, they are fairly specific to torch types, but their get/set
// semantics follow Python.
//===----------------------------------------------------------------------===//
def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
Symbol,
IsolatedFromAbove,
]> {
let summary = "A slot with global storage";
let description = [{
Represents a slot with global storage. The slot semantics are the same
as Python's: getting or setting a slot is done by object identity.
}];
let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$typeBound);
let results = (outs);
let assemblyFormat = [{
$sym_name attr-dict `:` $typeBound
}];
let extraClassDeclaration = [{
// The name of the function, which, for semantic correctness, must be called
// exactly once and this call must be done before any other calls into
// the module.
// TODO: Avoid load-bearing names.
// We could replace this with an op that marks the function as initializer.
static constexpr StringRef getGlobalSlotInitializerFuncName() {
return "__torch_global_slot_initializer";
}
}];
}
def Torch_GlobalSlotGetOp : Torch_Op<"global_slot.get", []> {
let summary = "Get the value stored in a torch.global_slot";
let arguments = (ins
FlatSymbolRefAttr:$slot
);
let results = (outs AnyTorchType:$result);
let assemblyFormat = [{
$slot attr-dict `:` type($result)
}];
}
def Torch_GlobalSlotSetOp : Torch_Op<"global_slot.set", []> {
let summary = "Set the value stored in a torch.global_slot";
let arguments = (ins
FlatSymbolRefAttr:$slot,
AnyTorchType:$value
);
let results = (outs);
let assemblyFormat = [{
$slot `=` $value attr-dict `:` type($value)
}];
}
//===----------------------------------------------------------------------===//
// TorchScript `prim::` ops.
//===----------------------------------------------------------------------===//
@ -142,7 +293,7 @@ def Torch_PrimGetAttrOp : Torch_Op<"prim.GetAttr", []> {
let results = (outs AnyTorchType:$result);
let assemblyFormat = [{
-$receiver `[` $name `]` attr-dict `:` type($result)
+$receiver `[` $name `]` attr-dict `:` type($receiver) `->` type($result)
}];
}
@ -157,7 +308,7 @@ def Torch_PrimSetAttrOp : Torch_Op<"prim.SetAttr", []> {
let results = (outs);
let assemblyFormat = [{
-$receiver `[` $name `]` `=` $value attr-dict `:` type($value)
+$receiver `[` $name `]` `=` $value attr-dict `:` type($receiver) `,` type($value)
}];
}
@ -172,7 +323,7 @@ def Torch_PrimCallMethodOp : Torch_Op<"prim.CallMethod", []> {
let results = (outs AnyTorchType:$result);
let assemblyFormat = [{
-$receiver `[` $name `]` `(` $operands `)` attr-dict `:` type($operands) `->` type($result)
+$receiver `[` $name `]` `(` $operands `)` attr-dict `:` type($receiver) `,` functional-type($operands, $result)
}];
}
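
For reference, the updated assembly formats also print the receiver type. A
hand-written illustration (the value names and types here are made up for the
example, not taken from the tests):

```mlir
%t = torch.prim.GetAttr %self["t1"] : !torch.nn.Module<"c"> -> !numpy.ndarray<*:!numpy.any_dtype>
torch.prim.SetAttr %self["t1"] = %t : !torch.nn.Module<"c">, !numpy.ndarray<*:!numpy.any_dtype>
%ret = torch.prim.CallMethod %self["forward"] (%t) : !torch.nn.Module<"c">, (!numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.NoneType
```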


@ -22,9 +22,62 @@ class Torch_Type<string name, string typeMnemonic> : TypeDef<Torch_Dialect, name
def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
let summary = "torch.nn.Module";
let description = [{
+Represents an instance of a `torch.nn.Module` with the given `className`.
+}];
+let parameters = (ins StringRefParameter<"class name">:$className);
+let printer = [{
+$_printer << "nn.Module<\"";
+llvm::printEscapedString(getImpl()->className, $_printer.getStream());
+$_printer << "\">";
+}];
+let parser = [{
+if (parser.parseLess())
+return Type();
+StringRef className;
+if ($_parser.parseOptionalString(&className))
+return Type();
+if ($_parser.parseGreater())
+return Type();
+return get($_ctxt, className);
}];
}
// TODO: It feels like this should be something more general.
// However, to do that, we need to agree on construction operations
// and the valid MLIR representations of the "None" state.
//
// For now, we only need it as a stand-in type to allow importing
// the `_is_full_backward_hook` optional bool type that Torch puts on
// all classes.
def Torch_OptionalType : Torch_Type<"Optional", "optional"> {
let summary = "!torch.optional<T>";
let description = [{
}];
let parameters = (ins "::mlir::Type":$containedType);
let printer = [{
$_printer << "optional<" << getImpl()->containedType << ">";
}];
let parser = [{
if (parser.parseLess())
return Type();
Type containedType;
if ($_parser.parseType(containedType))
return Type();
if ($_parser.parseGreater())
return Type();
return get($_ctxt, containedType);
}];
let builders = [
TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{
return Base::get(containedType.getContext(), containedType);
}]>
];
}
//===----------------------------------------------------------------------===//
// Type predicates


@ -0,0 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(NPCOMPTorchPassIncGen)
add_mlir_doc(Passes -gen-pass-doc NPCOMPTorchTransforms ./)


@ -0,0 +1,30 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace NPCOMP {
namespace Torch {
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
} // namespace Torch
/// Registers all Torch transformation passes.
void registerTorchPasses();
} // namespace NPCOMP
} // namespace mlir
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H


@ -0,0 +1,65 @@
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_TORCH_PASSES
#define NPCOMP_TORCH_PASSES
include "mlir/Pass/PassBase.td"
def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
let summary = "Converts TorchScript object graphs to a globalized form";
let constructor = "mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass()";
let description = [{
This pass converts a subset of possible TorchScript modules into a
more restrictive lower-level form that strips away the need to be
concerned with instances of !torch.nn.Module<...> type. Specifically,
the object graph is flattened into a set of discrete globals
(`torch.global_slot`) that hold the program state.
The overarching goal is for a strict correspondence between the original
`torch.nn.Module` (call it `root`) that the user `torch.jit.script`'ed, and
the public interface of the resulting MLIR module. Specifically:
- The call `root.encoder.forward(...)` in Python corresponds to invoking
the `func @encoder.forward` on the resulting MLIR module.
- The data member access `root.decoder.ids_to_strings_table` in Python
corresponds to accessing the
`torch.global_slot @decoder.ids_to_strings_table` on the resulting
MLIR module.
In effect, the entire MLIR module corresponds to an instance of the `root`
object. This matches with the intuitive behavior desired for deployment:
When the MLIR module (or, more likely, a compiled artifact derived from it)
is loaded in a deployed environment, it is equivalent to recreating the
original `root` object.
This pass performs a complete change of the externally visible calling
convention of the MLIR module for a graph of objects and methods to a
fixed set of globals and functions.
Of course, only a subset of programs can be transformed, and this pass fails
with an error if the conditions are violated.
Specifically, the restrictions are:
- There must be a unique torch.nn_module that is not the value of a slot
of any other torch.nn_module
- Rationale: Allows us to have a notion of a unique "root" op, which is
used to define linkage. This also matches how TorchScript imports in
practice (`torch.jit.script` imports a single root object).
- There must be exactly one instance of each torch.class_type. Equivalently,
every torch.nn_module must have a distinct type.
- Rationale: This guarantee precludes things like selecting between
multiple modules dynamically at runtime, which would require indirecting
between the separate storage of each instance.
- All torch.nn_module's must be reachable by a unique path from the root
- Rationale: Eliminates possibility of potentially exponential number of
paths. Or worse, infinite number of paths when considering cyclic
object graphs. Also as of Feb 2021, TorchScript won't import into
this form (it has a bug related to the identity of submodules).
}];
}
#endif // NPCOMP_TORCH_PASSES
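
To make the pass description concrete, here is a hand-written sketch of the
transformation on a trivial object graph. The names and types are illustrative
only; the authoritative examples are the
`test/Dialect/Torch/globalize-object-graph*.mlir` tests.

```mlir
// Before: object graph form.
torch.class_type @c {
  torch.attr "n" : i64
  torch.method "forward", @f
}
func private @f(%self: !torch.nn.Module<"c">) -> i64 {
  %n = torch.prim.GetAttr %self["n"] : !torch.nn.Module<"c"> -> i64
  return %n : i64
}
%n0 = basicpy.numeric_constant 0 : i64
torch.nn_module {
  torch.slot "n", %n0 : i64
} : !torch.nn.Module<"c">

// After -torch-globalize-object-graph (sketch): slots become globals named by
// their path from the root, methods become public functions with the module
// argument removed, and slot initial values move into the initializer func.
torch.global_slot @n : i64
func @forward() -> i64 {
  %n = torch.global_slot.get @n : i64
  return %n : i64
}
func @__torch_global_slot_initializer() {
  %n0 = basicpy.numeric_constant 0 : i64
  torch.global_slot.set @n = %n0 : i64
  return
}
```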


@ -156,7 +156,21 @@ int npcompTypeIsANnModule(MlirType t) {
return unwrap(t).isa<Torch::NnModuleType>();
}
-/** Gets the singleton torch.nn.Module type. */
+/** Gets the torch.nn.Module type of the specified class. */
-MlirType npcompNnModuleTypeGet(MlirContext context) {
+MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className) {
-return wrap(Torch::NnModuleType::get(unwrap(context)));
+return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
+}
+/*============================================================================*/
+/* torch.optional type. */
+/*============================================================================*/
+/** Checks whether the given type is a !torch.optional<T> type */
+int npcompTypeIsAOptional(MlirType t) {
+return unwrap(t).isa<Torch::OptionalType>();
+}
+/** Gets the !torch.optional<T> type with subtype T. */
+MlirType npcompOptionalTypeGet(MlirType containedType) {
+return wrap(Torch::OptionalType::get(unwrap(containedType)));
}


@ -1 +1,2 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)


@ -10,6 +10,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;


@ -14,6 +14,7 @@
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
+#include "llvm/ADT/StringMap.h"
using namespace mlir;
using namespace mlir::NPCOMP;
@ -49,8 +50,23 @@ KernelMetadata KernelCallOp::getTorchKernelMetadata() {
LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto func = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, function());
if (!func)
-return emitError() << "'" << function()
+return emitError() << "'@" << function()
<< "' does not reference a valid function";
if (func.getVisibility() != SymbolTable::Visibility::Private)
return emitError() << "'@" << function()
<< "' must reference a private function";
if (func.isDeclaration())
return emitError() << "'@" << function()
<< "' must reference a function that is defined (not "
"merely declared)";
auto expectedReceiverArgType = NnModuleType::get(
getContext(), getOperation()->getParentOfType<ClassTypeOp>().getName());
if (func.getType().getNumInputs() == 0 ||
func.getType().getInput(0) != expectedReceiverArgType) {
return emitError() << "the referenced function '" << function()
<< "' must have a first argument of type "
<< expectedReceiverArgType;
}
return success();
}
@ -60,8 +76,82 @@ LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
static LogicalResult verify(NnModuleOp op) {
for (Operation &child : *op.getBody())
-if (!isa<AttrOp, MethodOp, NnModuleTerminatorOp>(&child))
+if (!isa<SlotOp, NnModuleTerminatorOp>(&child))
-return child.emitOpError() << "is not allowed inside `torch.nn_module`";
+return child.emitOpError() << "is not allowed inside 'torch.nn_module'";
return success();
}
// PyTorch has a well-developed notion of subtyping.
//
// This is a restricted subset of it.
//
// TODO: Flesh this out.
bool isValidSubtype(Type subtype, Type type) {
if (subtype == type)
return true;
if (auto optional = type.dyn_cast<OptionalType>())
return subtype == optional.getContainedType() ||
subtype.isa<Basicpy::NoneType>();
return false;
}
LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto classType =
symbolTable.lookupNearestSymbolFrom<ClassTypeOp>(*this, getClassName());
if (!classType)
return emitError() << "'" << getClassName()
<< "' does not reference a valid class type";
auto attrs = llvm::to_vector<6>(getBody()->getOps<SlotOp>());
auto attrDefs = llvm::to_vector<6>(classType.getBody()->getOps<AttrOp>());
if (attrs.size() != attrDefs.size())
return emitError() << "number of 'torch.slot's in a 'torch.nn_module' must "
"match number of 'torch.attr's in "
"the corresponding 'torch.class_type'";
for (int i = 0, e = attrs.size(); i != e; i++) {
SlotOp attr = attrs[i];
AttrOp attrDef = attrDefs[i];
if (!isValidSubtype(attr.value().getType(), attrDef.type()) ||
attr.name() != attrDef.name()) {
return attr.emitOpError()
.append("is expected to match type and name of '",
attrDef.getOperation(), "'")
.attachNote(attrDef.getLoc())
.append("see torch.attr at corresponding index ", i, " here");
}
}
return success();
}
//===----------------------------------------------------------------------===//
// ClassTypeOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(ClassTypeOp op) {
llvm::StringMap<Operation *> namesToOps;
for (Operation &child : op.getBody()->without_terminator()) {
if (!isa<AttrOp, MethodOp>(&child))
return child.emitOpError() << "is not allowed inside `torch.class_type`";
StringRef name;
if (auto attr = dyn_cast<AttrOp>(child))
name = attr.name();
else
name = cast<MethodOp>(child).name();
auto itAndWasInserted = namesToOps.insert({name, &child});
auto it = itAndWasInserted.first;
bool wasInserted = itAndWasInserted.second;
if (!wasInserted) {
auto diag = op.emitOpError().append(
"has duplicate attr/method with name '", name, "'");
diag.attachNote(it->second->getLoc())
.append("see first conflicting attr/method here");
diag.attachNote(child.getLoc())
.append("see second conflicting attr/method here");
return failure();
}
}
return success();
}
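
One consequence of `isValidSubtype` worth illustrating: a `torch.slot` whose
value has the contained type (or `!basicpy.NoneType`) is accepted against a
`torch.attr` declared with a `!torch.optional<T>` type, which is what the
importer needs for attributes like the optional `_is_full_backward_hook`
mentioned in `TorchTypes.td`. A hand-written illustration (the class and
attribute names are made up for the example):

```mlir
torch.class_type @c {
  torch.attr "hook" : !torch.optional<!basicpy.BoolType>
}
// The slot value's type is the optional's contained type, so this
// nn_module verifies against the class type above.
%flag = basicpy.bool_constant true
torch.nn_module {
  torch.slot "hook", %flag : !basicpy.BoolType
} : !torch.nn.Module<"c">
```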


@ -0,0 +1,19 @@
add_npcomp_conversion_library(NPCOMPTorchPasses
Passes.cpp
GlobalizeObjectGraph.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms
DEPENDS
NPCOMPTorchPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
NPCOMPTorchDialect
NPCOMPBasicpyDialect
)


@ -0,0 +1,340 @@
//===- GlobalizeObjectGraph.cpp ----------------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "llvm/ADT/MapVector.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
namespace {
// See the pass documentation for `torch-globalize-object-graph`.
class ObjectGraphGlobalizer {
public:
ObjectGraphGlobalizer(ModuleOp module);
LogicalResult globalizeObjectGraph();
private:
FailureOr<NnModuleOp> findRootNnModule();
LogicalResult checkSingleInstanceOfEachClass();
LogicalResult recursivelyTraverseClassType(ClassTypeOp classType);
void createInitializerFunc();
LogicalResult rewriteMethods();
void removeObjectGraph();
ModuleOp module;
SymbolTable symbolTable;
OpBuilder globalBuilder;
// The stack of attribute names we have traversed during our recursive
// traversal of the class/object hierarchy.
//
// Linkage names are calculated based on the set of attribute names traversed
// from the root class/module in the program.
SmallVector<std::string> nameStack;
// Sometimes it is natural to want a map keyed on torch.attr ops or torch.slot
// ops. However, usually it is better to keep a map keyed on an ClassTypeOp
// + attr name since frequently that is all one has access to and it
// would be tedious to scan the body of the ClassTypeOp for the torch.attr
// with the corresponding name.
using AttrOfClass =
std::pair</*ClassTypeOp*/ Operation *, /*attr name*/ StringRef>;
// The initial value associated with an attribute of a class.
// Since we only allow a single instance of a class, this is equivalent to
// the initial value of the unique slot corresponding to that attr.
DenseMap<AttrOfClass, Value> slotInitialValues;
// The inverse map of `slotInitialValues`.
// Many attributes can have the same initial value, so the value type
// is a vector.
DenseMap<Value, std::vector<AttrOfClass>> slotInitialValuesInverseMap;
// The torch.global_slot corresponding to each torch.attr/torch.slot.
DenseMap<AttrOfClass, GlobalSlotOp> globalSlotForAttr;
// The linkage name (value) for the function with symbol name (key).
DenseMap<StringRef, std::string> methodLinkageNames;
// The set of class types that have already been processed.
// Used for diagnostics.
// The map value is the original path from the root that we found it at.
DenseMap</*ClassTypeOp*/ Operation *, std::string> seenClassTypes;
};
} // namespace
ObjectGraphGlobalizer::ObjectGraphGlobalizer(ModuleOp module)
: module(module), symbolTable(module),
globalBuilder(module.getBodyRegion()) {}
LogicalResult ObjectGraphGlobalizer::globalizeObjectGraph() {
// We require there to be a unique root !torch.nn.Module.
FailureOr<NnModuleOp> maybeRootNnModule = findRootNnModule();
if (failed(maybeRootNnModule))
return failure();
NnModuleOp rootNnModule = *maybeRootNnModule;
if (!rootNnModule)
return module.emitError()
<< "module does not contain a root torch.nn_module";
// We require one instance of each class. That is, there is a single
// torch.nn_module for each torch.class_type.
if (failed(checkSingleInstanceOfEachClass()))
return failure();
for (NnModuleOp nnModule : module.getOps<NnModuleOp>()) {
auto classType = symbolTable.lookup<ClassTypeOp>(nnModule.getClassName());
for (auto slot : nnModule.getOps<SlotOp>()) {
AttrOfClass attrOfClass = {classType, slot.name()};
slotInitialValues[attrOfClass] = slot.value();
slotInitialValuesInverseMap[slot.value()].push_back(attrOfClass);
}
}
// Recursively traverse the class hierarchy, globalizing slots and
// tracking linkage names for methods.
auto rootClassType =
symbolTable.lookup<ClassTypeOp>(rootNnModule.getClassName());
if (failed(recursivelyTraverseClassType(rootClassType)))
return failure();
// Move all slot initial values into an initializer func.
createInitializerFunc();
// Rewrite torch.prim.GetAttr/torch.prim.SetAttr/torch.prim.CallMethod.
if (failed(rewriteMethods()))
return failure();
// Now that we have finished converting to the new form, remove
// the original object graph.
removeObjectGraph();
return success();
}
FailureOr<NnModuleOp> ObjectGraphGlobalizer::findRootNnModule() {
NnModuleOp rootNnModule;
for (NnModuleOp op : module.getOps<NnModuleOp>()) {
if (!op.use_empty())
continue;
if (rootNnModule) {
op.emitError()
.append("found more than one root module (module that is not a "
"child of any other module)")
.attachNote(rootNnModule.getLoc())
.append("see other root module here");
return failure();
}
rootNnModule = op;
}
return rootNnModule;
}
LogicalResult ObjectGraphGlobalizer::checkSingleInstanceOfEachClass() {
llvm::MapVector</*ClassTypeOp*/ Operation *, std::vector<NnModuleOp>>
classInstances;
for (NnModuleOp op : module.getOps<NnModuleOp>()) {
auto classType = symbolTable.lookup<ClassTypeOp>(op.getClassName());
classInstances[classType].push_back(op);
}
for (auto &p : classInstances) {
ClassTypeOp classType = cast<ClassTypeOp>(p.first);
ArrayRef<NnModuleOp> instances = p.second;
if (instances.size() > 1) {
// TODO: Improve this diagnostic based on user use cases.
// This is a user-facing diagnostic that enforces key invariants to
// our TorchScript subset.
auto diag = classType.emitError(
"class type has more than one instance: the current TorchScript "
"supported subset only allows single instances");
for (NnModuleOp instance : instances) {
diag.attachNote(instance.getLoc()) << "see instance here";
}
return failure();
}
}
return success();
}
LogicalResult
ObjectGraphGlobalizer::recursivelyTraverseClassType(ClassTypeOp classType) {
std::string pathToClassFromRoot = llvm::join(nameStack, ".");
if (!seenClassTypes.insert({classType, pathToClassFromRoot}).second) {
return classType.emitError()
<< "reachable by multiple paths from root object: '<root>."
<< seenClassTypes[classType] << "' and '<root>."
<< pathToClassFromRoot << "'";
}
// For each attr, create a global slot for it.
for (auto attr : classType.getOps<AttrOp>()) {
nameStack.push_back(attr.name().str());
if (auto type = attr.type().dyn_cast<NnModuleType>()) {
recursivelyTraverseClassType(
symbolTable.lookup<ClassTypeOp>(type.getClassName()));
} else {
auto linkageName = llvm::join(nameStack, ".");
auto globalSlot = globalBuilder.create<GlobalSlotOp>(
attr->getLoc(), linkageName, TypeAttr::get(attr.type()));
AttrOfClass attrOfClass = {classType, attr.name()};
assert(globalSlotForAttr.find(attrOfClass) == globalSlotForAttr.end());
globalSlotForAttr[attrOfClass] = globalSlot;
}
nameStack.pop_back();
}
// For each method, track the linkage name it will eventually have.
for (auto method : classType.getOps<MethodOp>()) {
nameStack.push_back(method.name().str());
auto linkageName = llvm::join(nameStack, ".");
nameStack.pop_back();
if (!methodLinkageNames.insert({method.function(), linkageName}).second)
method.emitError() << "unbound function shared by multiple methods";
}
return success();
}
void ObjectGraphGlobalizer::createInitializerFunc() {
auto loc = module.getLoc();
auto func = globalBuilder.create<FuncOp>(
loc, GlobalSlotOp::getGlobalSlotInitializerFuncName(),
globalBuilder.getFunctionType({}, {}));
OpBuilder builder(func.getContext());
Block *body = builder.createBlock(&func.getBody());
SmallVector<Operation *> opsToMove;
for (Operation &op : llvm::make_early_inc_range(*module.getBody())) {
if (isa<ClassTypeOp, NnModuleOp, GlobalSlotOp, FuncOp, ModuleTerminatorOp>(
&op))
continue;
op.moveBefore(body, body->end());
for (Value result : llvm::make_early_inc_range(op.getResults())) {
auto it = slotInitialValuesInverseMap.find(result);
if (it == slotInitialValuesInverseMap.end())
continue;
for (AttrOfClass attrOfClass : it->second) {
GlobalSlotOp globalSlot = globalSlotForAttr[attrOfClass];
OpBuilder::atBlockEnd(body).create<GlobalSlotSetOp>(
globalSlot.getLoc(), globalSlot.sym_name(), result);
}
}
}
builder.create<ReturnOp>(loc);
}
LogicalResult ObjectGraphGlobalizer::rewriteMethods() {
DenseMap<AttrOfClass, StringRef> linkageNames;
for (auto classType : module.getOps<ClassTypeOp>()) {
for (auto method : classType.getOps<MethodOp>()) {
auto it = methodLinkageNames.find(method.function());
if (it == methodLinkageNames.end())
continue;
linkageNames[{classType, method.name()}] = it->second;
}
}
// We only handle a small subset of ops that conform with the set of
// assumptions that allow us to globalize the object graph. Anything that
// tries to treat modules as bona-fide objects and not just namespaces
// of methods with a single instance of the corresponding type just gets
// arbitrarily tricky to rewrite. E.g. what if the user creates a list
// of modules, or there is an scf.if selecting between modules, etc.
auto rewriteOpWithNnModuleTypeOperand = [&](Operation *op) {
if (auto primSetAttr = dyn_cast<PrimSetAttrOp>(op)) {
auto classType = symbolTable.lookup<ClassTypeOp>(
primSetAttr.receiver().getType().cast<NnModuleType>().getClassName());
auto globalSlot = globalSlotForAttr[{classType, primSetAttr.name()}];
OpBuilder(primSetAttr)
.create<GlobalSlotSetOp>(primSetAttr.getLoc(), globalSlot.sym_name(),
primSetAttr.value());
primSetAttr.erase();
}
if (auto primGetAttr = dyn_cast<PrimGetAttrOp>(op)) {
auto classType = symbolTable.lookup<ClassTypeOp>(
primGetAttr.receiver().getType().cast<NnModuleType>().getClassName());
auto globalSlot = globalSlotForAttr[{classType, primGetAttr.name()}];
auto globalSlotGet = OpBuilder(primGetAttr)
.create<GlobalSlotGetOp>(primGetAttr.getLoc(),
primGetAttr.getType(),
globalSlot.sym_name());
primGetAttr.replaceAllUsesWith(globalSlotGet.getOperation());
primGetAttr.erase();
}
if (auto primCallMethod = dyn_cast<PrimCallMethodOp>(op)) {
auto classType = symbolTable.lookup<ClassTypeOp>(primCallMethod.receiver()
.getType()
.cast<NnModuleType>()
.getClassName());
StringRef linkageName = linkageNames[{classType, primCallMethod.name()}];
auto call = OpBuilder(primCallMethod)
.create<CallOp>(primCallMethod.getLoc(), linkageName,
primCallMethod.getType(),
primCallMethod.operands());
primCallMethod.replaceAllUsesWith(call);
primCallMethod.erase();
}
};
for (auto classType : module.getOps<ClassTypeOp>()) {
for (auto method : classType.getOps<MethodOp>()) {
auto it = methodLinkageNames.find(method.function());
if (it == methodLinkageNames.end())
continue;
FuncOp func = symbolTable.lookup<FuncOp>(method.function());
func.setVisibility(SymbolTable::Visibility::Public);
func.setName(it->second);
func.walk(rewriteOpWithNnModuleTypeOperand);
SmallVector<unsigned> argsToErase;
for (auto arg : llvm::enumerate(func.getArguments())) {
if (!arg.value().getType().isa<NnModuleType>())
continue;
if (!arg.value().use_empty()) {
// TODO: Improve this based on real user use cases.
// This is a diagnostic that users will hit if they do not conform to
// the supported subset of TorchScript.
auto diag = func.emitError().append(
"func argument at index ", arg.index(),
" has uses that were not able to be converted");
for (Operation *user : arg.value().getUsers())
diag.attachNote(user->getLoc()).append("see user here");
return failure();
}
argsToErase.push_back(arg.index());
}
func.eraseArguments(argsToErase);
}
}
return success();
}
void ObjectGraphGlobalizer::removeObjectGraph() {
for (Operation &op : llvm::make_early_inc_range(*module.getBody())) {
if (isa<ClassTypeOp, NnModuleOp>(op))
op.erase();
}
}
namespace {
class GlobalizeObjectGraphPass
: public GlobalizeObjectGraphBase<GlobalizeObjectGraphPass> {
void runOnOperation() override {
if (failed(ObjectGraphGlobalizer(getOperation()).globalizeObjectGraph()))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass() {
return std::make_unique<GlobalizeObjectGraphPass>();
}


@ -0,0 +1,25 @@
//===- PassDetail.h - Pass details ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace NPCOMP {
namespace Torch {
#define GEN_PASS_CLASSES
#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc"
} // namespace Torch
} // namespace NPCOMP
} // end namespace mlir
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H


@ -0,0 +1,20 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
#define GEN_PASS_REGISTRATION
#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc"
} // end namespace
void mlir::NPCOMP::registerTorchPasses() { ::registerPasses(); }


@ -21,6 +21,7 @@
#include "npcomp/Dialect/TCP/IR/TCPDialect.h" #include "npcomp/Dialect/TCP/IR/TCPDialect.h"
#include "npcomp/Dialect/TCP/Transforms/Passes.h" #include "npcomp/Dialect/TCP/Transforms/Passes.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "npcomp/Typing/Transforms/Passes.h" #include "npcomp/Typing/Transforms/Passes.h"
#include "npcomp/Conversion/Passes.h" #include "npcomp/Conversion/Passes.h"
@ -47,5 +48,6 @@ void mlir::NPCOMP::registerAllPasses() {
  mlir::NPCOMP::registerNumpyPasses();
  mlir::NPCOMP::registerTCFPasses();
  mlir::NPCOMP::registerTCPPasses();
  mlir::NPCOMP::registerTorchPasses();
  mlir::NPCOMP::registerTypingPasses();
}


@ -0,0 +1,63 @@
// RUN: npcomp-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
torch.class_type @c1 {}
torch.class_type @c2 {}
// expected-note @+1 {{see other root module here}}
torch.nn_module {} : !torch.nn.Module<"c1">
// expected-error @+1 {{found more than one root module (module that is not a child of any other module)}}
torch.nn_module {} : !torch.nn.Module<"c2">
// -----
// expected-error @+1 {{class type has more than one instance: the current TorchScript supported subset only allows single instances}}
torch.class_type @child {}
torch.class_type @parent {
torch.attr "m1" : !torch.nn.Module<"child">
torch.attr "m2" : !torch.nn.Module<"child">
}
// expected-note @+1 {{see instance here}}
%0 = torch.nn_module {} : !torch.nn.Module<"child">
// expected-note @+1 {{see instance here}}
%1 = torch.nn_module {} : !torch.nn.Module<"child">
%root = torch.nn_module {
torch.slot "m1", %0 : !torch.nn.Module<"child">
torch.slot "m2", %1 : !torch.nn.Module<"child">
} : !torch.nn.Module<"parent">
// -----
torch.class_type @c {
torch.method "f", @f
}
// expected-error @+1 {{func argument at index 1 has uses that were not able to be converted}}
func private @f(%arg0: !torch.nn.Module<"c">, %arg1: !torch.nn.Module<"c">) {
// expected-note @+1 {{see user here}}
%0 = basicpy.build_list %arg1 : (!torch.nn.Module<"c">) -> !basicpy.ListType
return
}
torch.nn_module {} : !torch.nn.Module<"c">
// -----
// expected-error @+1 {{reachable by multiple paths from root object: '<root>.m' and '<root>.m2'}}
torch.class_type @child {
torch.attr "float" : f64
}
torch.class_type @parent {
torch.attr "m" : !torch.nn.Module<"child">
torch.attr "m2" : !torch.nn.Module<"child">
}
%c42 = std.constant 42.0 : f64
%child = torch.nn_module {
torch.slot "float", %c42 : f64
} : !torch.nn.Module<"child">
%parent = torch.nn_module {
torch.slot "m", %child : !torch.nn.Module<"child">
torch.slot "m2", %child : !torch.nn.Module<"child">
} : !torch.nn.Module<"parent">
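For reference, a module graph of roughly the following Python shape would be expected to hit the last diagnostic above, since the shared Child instance has no unique dotted path from the root object. This is an illustrative sketch only; the class names and values are not taken from the test suite.

import torch

class Child(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.val = 42.0

    def forward(self) -> float:
        return self.val

class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        shared = Child()
        # The same Child instance is reachable as both `m` and `m2`, so its
        # slots have no unique dotted path from the root object.
        self.m = shared
        self.m2 = shared

    def forward(self) -> float:
        return self.m() + self.m2()

scripted = torch.jit.script(Parent())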


@ -0,0 +1,39 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
torch.class_type @c {
torch.attr "float" : f64
torch.method "test_get", @test_get
torch.method "test_set", @test_set
torch.method "test_call", @test_call
}
// CHECK-LABEL: func @test_get() -> f64 {
// CHECK: %[[V:.*]] = torch.global_slot.get @float : f64
// CHECK: return %[[V]] : f64
func private @test_get(%arg0: !torch.nn.Module<"c">) -> f64 {
%0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> f64
return %0 : f64
}
// CHECK-LABEL: func @test_set(
// CHECK-SAME: %[[A:.*]]: f64) {
// CHECK: torch.global_slot.set @float = %[[A]] : f64
// CHECK: return
func private @test_set(%arg0: !torch.nn.Module<"c">, %arg1: f64) {
torch.prim.SetAttr %arg0["float"] = %arg1 : !torch.nn.Module<"c">, f64
return
}
// CHECK-LABEL: func @test_call(
// CHECK-SAME: %[[A:.*]]: f64) -> f64 {
// CHECK: %[[V:.*]] = call @test_call(%[[A]]) : (f64) -> f64
// CHECK: return %[[V]] : f64
func private @test_call(%arg0: !torch.nn.Module<"c">, %arg1: f64) -> f64 {
%0 = torch.prim.CallMethod %arg0["test_call"] (%arg1) : !torch.nn.Module<"c">, (f64) -> f64
return %0 : f64
}
%c42 = std.constant 42.0 : f64
torch.nn_module {
torch.slot "float", %c42 : f64
} : !torch.nn.Module<"c">
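For orientation, methods like the following on a scripted module produce roughly the IR tested above. This is a sketch, not taken from the test suite: the attribute is named `value` instead of `float` to avoid shadowing `nn.Module.float()`, `test_call` calls the other two methods rather than itself, and the `forward` method exists only so the module scripts cleanly.

import torch

class C(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.value = 42.0

    def test_get(self) -> float:
        return self.value

    def test_set(self, v: float) -> None:
        self.value = v

    def test_call(self, v: float) -> float:
        # Method calls on `self` import as torch.prim.CallMethod, which the
        # pass rewrites into direct calls to the globalized functions.
        self.test_set(v)
        return self.test_get()

    def forward(self, v: float) -> float:
        return self.test_call(v)

scripted = torch.jit.script(C())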


@ -0,0 +1,25 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// Check that linkage names consist of the dotted path from the root.
// CHECK-LABEL: torch.global_slot @m.float : f64
// CHECK-LABEL: func @__torch_global_slot_initializer() {
// CHECK: %[[C42:.*]] = constant 4.200000e+01 : f64
// CHECK: torch.global_slot.set @m.float = %[[C42]] : f64
// CHECK: return
torch.class_type @child {
torch.attr "float" : f64
}
torch.class_type @parent {
torch.attr "m" : !torch.nn.Module<"child">
}
%c42 = std.constant 42.0 : f64
%child = torch.nn_module {
torch.slot "float", %c42 : f64
} : !torch.nn.Module<"child">
%parent = torch.nn_module {
torch.slot "m", %child : !torch.nn.Module<"child">
} : !torch.nn.Module<"parent">
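This is the well-formed counterpart of the aliased-instance sketch earlier: one submodule attribute, whose child slot globalizes under the dotted path from the root (the `@m.float` slot checked above). A rough Python analogue, with the attribute renamed from `float` to `val` to avoid shadowing `nn.Module.float()`:

import torch

class Child(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.val = 42.0

    def forward(self) -> float:
        return self.val

class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.m = Child()  # Child's slots globalize under the "m." prefix.

    def forward(self) -> float:
        return self.m()

scripted = torch.jit.script(Parent())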


@ -0,0 +1,63 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// Basic case.
// CHECK-LABEL: torch.global_slot @b : !basicpy.BoolType
// CHECK: torch.global_slot @i : i64
// CHECK: torch.global_slot @f : f64
// CHECK: torch.global_slot @a : !numpy.ndarray<*:!numpy.any_dtype>
// CHECK-LABEL: func @__torch_global_slot_initializer() {
// CHECK: %[[CB:.*]] = basicpy.bool_constant true
// CHECK: torch.global_slot.set @b = %[[CB]] : !basicpy.BoolType
// CHECK: %[[CI:.*]] = basicpy.numeric_constant 3 : i64
// CHECK: torch.global_slot.set @i = %[[CI]] : i64
// CHECK: %[[CF:.*]] = basicpy.numeric_constant 4.250000e+01 : f64
// CHECK: torch.global_slot.set @f = %[[CF]] : f64
// CHECK: %[[C:.*]] = constant dense<1.000000e+00> : tensor<1xf32>
// CHECK: %[[CA:.*]] = numpy.create_array_from_tensor %[[C]] : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
// CHECK: torch.global_slot.set @a = %[[CA]] : !numpy.ndarray<*:!numpy.any_dtype>
// CHECK: return
torch.class_type @c {
torch.attr "b" : !basicpy.BoolType
torch.attr "i" : i64
torch.attr "f" : f64
torch.attr "a" : !numpy.ndarray<*:!numpy.any_dtype>
}
%bool_true = basicpy.bool_constant true
%i = basicpy.numeric_constant 3 : i64
%f = basicpy.numeric_constant 4.250000e+01 : f64
%cst = constant dense<1.000000e+00> : tensor<1xf32>
%a = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
torch.nn_module {
torch.slot "b", %bool_true : !basicpy.BoolType
torch.slot "i", %i : i64
torch.slot "f", %f : f64
torch.slot "a", %a : !numpy.ndarray<*:!numpy.any_dtype>
} : !torch.nn.Module<"c">
// -----
// Same SSA value used as initializer for multiple slots.
// CHECK-LABEL: torch.global_slot @b1 : !basicpy.BoolType
// CHECK-LABEL: torch.global_slot @b2 : !basicpy.BoolType
// CHECK-LABEL: func @__torch_global_slot_initializer() {
// CHECK: %[[TRUE:.*]] = basicpy.bool_constant true
// CHECK: torch.global_slot.set @b1 = %[[TRUE]] : !basicpy.BoolType
// CHECK: torch.global_slot.set @b2 = %[[TRUE]] : !basicpy.BoolType
// CHECK: return
// CHECK: }
torch.class_type @c {
torch.attr "b1" : !basicpy.BoolType
torch.attr "b2" : !basicpy.BoolType
}
%bool_true = basicpy.bool_constant true
torch.nn_module {
torch.slot "b1", %bool_true : !basicpy.BoolType
torch.slot "b2", %bool_true : !basicpy.BoolType
} : !torch.nn.Module<"c">
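In Python terms, plain attributes on a scripted module like the sketch below are what become the `torch.attr`/`torch.slot` pairs and, after this pass, the global slots and initializer checked above. This is a rough analogue; the exact MLIR types come from how the importer maps each Python value, and the `forward` method is only there so the module scripts cleanly.

import torch

class C(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.b = True           # -> torch.attr "b" : !basicpy.BoolType
        self.i = 3              # -> torch.attr "i" : i64
        self.f = 42.5           # -> torch.attr "f" : f64
        self.a = torch.ones(1)  # -> torch.attr "a" : !numpy.ndarray<...>

    def forward(self) -> float:
        return self.f

scripted = torch.jit.script(C())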


@ -2,14 +2,98 @@
// -----
torch.class_type @c {}
%0 = torch.nn_module {
// expected-error @+1 {{'func' op is not allowed inside 'torch.nn_module'}}
func @f()
} : !torch.nn.Module<"c">
// -----
torch.class_type @c {}
%c0 = constant 0 : i64
// expected-error @+1 {{number of 'torch.slot's in a 'torch.nn_module' must match number of 'torch.attr's in the corresponding 'torch.class_type'}}
%0 = torch.nn_module {
torch.slot "f", %c0 : i64
} : !torch.nn.Module<"c">
// -----
torch.class_type @c {
// expected-note @+1 {{see torch.attr at corresponding index 0 here}}
torch.attr "g" : i64
}
%c0 = constant 0 : i64
%0 = torch.nn_module {
// expected-error @+1 {{'torch.slot' op is expected to match type and name of 'torch.attr "g" : i64'}}
torch.slot "f", %c0 : i64
} : !torch.nn.Module<"c">
// -----
torch.class_type @c {
// expected-error @+1 {{'func' op is not allowed inside `torch.class_type`}}
func @f()
}
// -----
// expected-error @+1 {{has duplicate attr/method with name 'a'}}
torch.class_type @c {
// expected-note @+1 {{see first conflicting attr/method here}}
torch.attr "a" : i64
// expected-note @+1 {{see second conflicting attr/method here}}
torch.attr "a" : i64
}
// -----
torch.class_type @c {
// expected-error @+1 {{'@invalidSym' does not reference a valid function}}
torch.method "f", @invalidSym
}
// -----
torch.class_type @c {
// expected-error @+1 {{'@f' must reference a private function}}
torch.method "f", @f
}
func @f(%arg0: !torch.nn.Module<"c">) {
return
}
// -----
torch.class_type @c {
// expected-error @+1 {{'@f' must reference a function that is defined (not merely declared)}}
torch.method "f", @f
}
func private @f(%arg0: !torch.nn.Module<"c">)
// -----
func private @f() {
return
}
torch.class_type @c {
// expected-error @+1 {{the referenced function 'f' must have a first argument of type '!torch.nn.Module<"c">'}}
torch.method "f", @f
}
// -----
func private @f(!torch.nn.Module<"other_c">) {
return
}
torch.class_type @c {
// expected-error @+1 {{the referenced function 'f' must have a first argument of type '!torch.nn.Module<"c">'}}
torch.method "f", @f
}
// -----
// expected-error @+1 {{'a' does not reference a valid class type}}
%m = torch.nn_module {} : !torch.nn.Module<"a">


@ -14,16 +14,28 @@ func @kernel_call(%arg0 : si32, %arg1 : tensor<3x4xf32>) -> tensor<*xf32> {
%num = basicpy.numeric_constant 4.250000e+01 : f64
%cst = constant dense<1.000000e+00> : tensor<1xf32>
%array = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
%none = basicpy.singleton : !basicpy.NoneType
func private @f(%arg0: !torch.nn.Module<"test">) {
return
}
torch.class_type @empty {}
%submodule = torch.nn_module {} : !torch.nn.Module<"empty">
torch.class_type @test {
torch.attr "b" : !basicpy.BoolType
torch.attr "i" : i64
torch.attr "f" : f64
torch.attr "t" : !numpy.ndarray<*:!numpy.any_dtype>
torch.attr "submodule" : !torch.nn.Module<"empty">
torch.attr "ob" : !torch.optional<!basicpy.BoolType>
torch.method "method", @f
}
torch.nn_module {
torch.slot "b", %bool_true : !basicpy.BoolType
torch.slot "i", %num3_i64 : i64
torch.slot "f", %num : f64
torch.slot "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
torch.slot "submodule", %submodule : !torch.nn.Module<"empty">
torch.slot "ob", %none : !basicpy.NoneType
} : !torch.nn.Module<"test">