mirror of https://github.com/llvm/torch-mlir
Implement GlobalizeObjectGraph transformation.
This required restructuring of how we model TorchScript on import. The main difference is that now we split out a `torch.class_type` that holds methods and declarations of the types of each slot. This is more consistent with TorchScript (our previous representation was "denormalized"). Recommended reading order: 1. check out the description of `torch.class_type` in `TorchOps.td` and look at `test/Dialect/Torch/ops.mlir` and `frontends/pytorch/test/module_import/` to familiarize with the new representation. - Just look at the new IR. The diff between the old names and new names is confusing. 2. check out `test/Dialect/Torch/globalize-object-graph*.mlir` and read along with the pass description in `include/npcomp/Dialect/Torch/Transforms/Passes.td` 3. Read the code in `GlobalizeObjectGraph.cpp` and miscellaneous changes in `ivalue_importer.cpp`, `TorchOps.cpp`, etc.pull/162/head
parent
99d1db18d2
commit
158c5c484d
|
@ -101,7 +101,8 @@ public:
|
||||||
private:
|
private:
|
||||||
MlirValue rawImportIValue(c10::IValue value);
|
MlirValue rawImportIValue(c10::IValue value);
|
||||||
MlirValue importModule(torch::jit::Module jitModule);
|
MlirValue importModule(torch::jit::Module jitModule);
|
||||||
void importMethod(torch::jit::Function *function, MlirBlock nnModuleBody);
|
void importMethod(torch::jit::Function *function, MlirBlock classTypeBody);
|
||||||
|
void importClassType(c10::ClassType *classType);
|
||||||
|
|
||||||
MlirBlock importBlock;
|
MlirBlock importBlock;
|
||||||
MlirContext context;
|
MlirContext context;
|
||||||
|
@ -111,6 +112,12 @@ private:
|
||||||
std::unordered_map<c10::IValue, MlirValue, IValueHasher, IValueEq> valueMap;
|
std::unordered_map<c10::IValue, MlirValue, IValueHasher, IValueEq> valueMap;
|
||||||
// Used to detect potentially aliasing tensors.
|
// Used to detect potentially aliasing tensors.
|
||||||
std::unordered_set<c10::StorageImpl *> seenStorageImpls;
|
std::unordered_set<c10::StorageImpl *> seenStorageImpls;
|
||||||
|
// The set of ClassType's that have already been imported.
|
||||||
|
//
|
||||||
|
// ClassType's are referenced via their `classType->name()->qualifiedName()`
|
||||||
|
// string (as an MLIR symbol name) so we don't need to keep a map associating
|
||||||
|
// them with the MlirOperation that they import into.
|
||||||
|
std::unordered_set<c10::ClassType *> classTypes;
|
||||||
// The stack of attribute names we have traversed to reach the current IValue.
|
// The stack of attribute names we have traversed to reach the current IValue.
|
||||||
// Used for diagnostics.
|
// Used for diagnostics.
|
||||||
std::vector<std::string> attributeNameStack;
|
std::vector<std::string> attributeNameStack;
|
||||||
|
@ -128,16 +135,25 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
||||||
// TODO: Can we do better?
|
// TODO: Can we do better?
|
||||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||||
|
|
||||||
MlirOperation nnModule =
|
c10::optional<c10::QualifiedName> maybeName = currentModule.type()->name();
|
||||||
createMlirOperation("torch.nn_module", loc,
|
if (!maybeName) {
|
||||||
npcompNnModuleTypeGet(context), mlirRegionCreate());
|
throw std::invalid_argument("cannot import unnamed module");
|
||||||
|
}
|
||||||
|
std::string moduleTypeName = maybeName->qualifiedName();
|
||||||
|
|
||||||
|
// Ensure the class type has been imported.
|
||||||
|
importClassType(currentModule.type().get());
|
||||||
|
|
||||||
|
MlirOperation nnModule = createMlirOperation(
|
||||||
|
"torch.nn_module", loc,
|
||||||
|
npcompNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
|
||||||
|
mlirRegionCreate());
|
||||||
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
||||||
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
|
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
|
||||||
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
|
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
|
||||||
|
|
||||||
if (!rootModuleName.has_value()) {
|
if (!rootModuleName.has_value()) {
|
||||||
c10::optional<c10::QualifiedName> maybeName = currentModule.type()->name();
|
rootModuleName = moduleTypeName;
|
||||||
rootModuleName = maybeName ? maybeName->qualifiedName() : "unnamed module";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<c10::IValue> &slots = currentModule._ivalue()->slots();
|
const std::vector<c10::IValue> &slots = currentModule._ivalue()->slots();
|
||||||
|
@ -151,7 +167,7 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
||||||
MlirValue slotValue = importIValue(slots[i]);
|
MlirValue slotValue = importIValue(slots[i]);
|
||||||
// TODO: Is it necessary to track whether an attribute is a "parameter"?
|
// TODO: Is it necessary to track whether an attribute is a "parameter"?
|
||||||
createMlirOperationAtEnd(
|
createMlirOperationAtEnd(
|
||||||
nnModuleBody, "torch.attr", loc, slotValue,
|
nnModuleBody, "torch.slot", loc, slotValue,
|
||||||
toMlirNamedAttribute(
|
toMlirNamedAttribute(
|
||||||
"name", mlirStringAttrGet(
|
"name", mlirStringAttrGet(
|
||||||
context, toMlirStringRef(classAttribute.getName()))));
|
context, toMlirStringRef(classAttribute.getName()))));
|
||||||
|
@ -162,10 +178,6 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
||||||
rootModuleName = c10::nullopt;
|
rootModuleName = c10::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (torch::jit::Function *function : currentModule.type()->methods()) {
|
|
||||||
importMethod(function, nnModuleBody);
|
|
||||||
}
|
|
||||||
|
|
||||||
createMlirOperationAtEnd(nnModuleBody, "torch.nn_module_terminator", loc);
|
createMlirOperationAtEnd(nnModuleBody, "torch.nn_module_terminator", loc);
|
||||||
mlirBlockInsertOwnedOperationBefore(
|
mlirBlockInsertOwnedOperationBefore(
|
||||||
importBlock, mlirBlockGetTerminator(importBlock), nnModule);
|
importBlock, mlirBlockGetTerminator(importBlock), nnModule);
|
||||||
|
@ -262,7 +274,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void IValueImporter::importMethod(torch::jit::Function *function,
|
void IValueImporter::importMethod(torch::jit::Function *function,
|
||||||
MlirBlock nnModuleBody) {
|
MlirBlock classTypeBody) {
|
||||||
// We make an effort for the func op's symbol name to be useful for debugging,
|
// We make an effort for the func op's symbol name to be useful for debugging,
|
||||||
// but still clearly non-load-bearing.
|
// but still clearly non-load-bearing.
|
||||||
std::string symName =
|
std::string symName =
|
||||||
|
@ -275,13 +287,50 @@ void IValueImporter::importMethod(torch::jit::Function *function,
|
||||||
mlirBlockInsertOwnedOperationBefore(
|
mlirBlockInsertOwnedOperationBefore(
|
||||||
importBlock, mlirBlockGetTerminator(importBlock), func);
|
importBlock, mlirBlockGetTerminator(importBlock), func);
|
||||||
createMlirOperationAtEnd(
|
createMlirOperationAtEnd(
|
||||||
nnModuleBody, "torch.method", mlirLocationUnknownGet(context),
|
classTypeBody, "torch.method", mlirLocationUnknownGet(context),
|
||||||
toMlirNamedAttribute(
|
toMlirNamedAttribute(
|
||||||
"name",
|
"name",
|
||||||
mlirStringAttrGet(context, toMlirStringRef(function->name()))),
|
mlirStringAttrGet(context, toMlirStringRef(function->name()))),
|
||||||
toMlirNamedAttribute("function", mlirFlatSymbolRefAttrGet(
|
toMlirNamedAttribute("function", mlirFlatSymbolRefAttrGet(
|
||||||
context, toMlirStringRef(symName))));
|
context, toMlirStringRef(symName))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void IValueImporter::importClassType(c10::ClassType *classType) {
|
||||||
|
if (!classTypes.insert(classType).second) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Can we do better?
|
||||||
|
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||||
|
|
||||||
|
MlirOperation op = createMlirOperationAtEnd(
|
||||||
|
importBlock, "torch.class_type", loc, mlirRegionCreate(),
|
||||||
|
toMlirNamedAttribute(
|
||||||
|
"sym_name",
|
||||||
|
mlirStringAttrGet(
|
||||||
|
context, toMlirStringRef(classType->name()->qualifiedName()))));
|
||||||
|
MlirRegion region = mlirOperationGetRegion(op, 0);
|
||||||
|
mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr));
|
||||||
|
MlirBlock classTypeBody = mlirRegionGetFirstBlock(region);
|
||||||
|
|
||||||
|
for (const c10::ClassAttribute &classAttribute : classType->getAttributes()) {
|
||||||
|
createMlirOperationAtEnd(
|
||||||
|
classTypeBody, "torch.attr", loc,
|
||||||
|
toMlirNamedAttribute(
|
||||||
|
"name", mlirStringAttrGet(
|
||||||
|
context, toMlirStringRef(classAttribute.getName()))),
|
||||||
|
toMlirNamedAttribute("type",
|
||||||
|
mlirTypeAttrGet(typeMapper.mapFromTorchType(
|
||||||
|
loc, classAttribute.getType()))));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (torch::jit::Function *function : classType->methods()) {
|
||||||
|
importMethod(function, classTypeBody);
|
||||||
|
}
|
||||||
|
|
||||||
|
createMlirOperationAtEnd(classTypeBody, "torch.class_type_terminator", loc);
|
||||||
|
}
|
||||||
|
|
||||||
void torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
|
void torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
|
||||||
MlirContext context) {
|
MlirContext context) {
|
||||||
// When debugging module importing, it can be useful to dump as so:
|
// When debugging module importing, it can be useful to dump as so:
|
||||||
|
|
|
@ -106,11 +106,19 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
||||||
return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType);
|
return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType);
|
||||||
}
|
}
|
||||||
case TypeKind::ClassType: {
|
case TypeKind::ClassType: {
|
||||||
return npcompNnModuleTypeGet(context);
|
auto maybeName = torchType->cast<c10::ClassType>()->name();
|
||||||
|
return npcompNnModuleTypeGet(
|
||||||
|
context, toMlirStringRef(maybeName ? maybeName->qualifiedName()
|
||||||
|
: "unnamed class"));
|
||||||
}
|
}
|
||||||
case TypeKind::FloatType: {
|
case TypeKind::FloatType: {
|
||||||
return mlirF64TypeGet(context);
|
return mlirF64TypeGet(context);
|
||||||
}
|
}
|
||||||
|
case TypeKind::OptionalType: {
|
||||||
|
return npcompOptionalTypeGet(
|
||||||
|
mapFromTorchType(
|
||||||
|
loc, torchType->cast<c10::OptionalType>()->getElementType()));
|
||||||
|
}
|
||||||
case TypeKind::IntType: {
|
case TypeKind::IntType: {
|
||||||
return mlirIntegerTypeGet(context, 64);
|
return mlirIntegerTypeGet(context, 64);
|
||||||
}
|
}
|
||||||
|
@ -120,6 +128,10 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
||||||
case TypeKind::BoolType: {
|
case TypeKind::BoolType: {
|
||||||
return npcompBoolTypeGet(context);
|
return npcompBoolTypeGet(context);
|
||||||
}
|
}
|
||||||
|
case TypeKind::ListType: {
|
||||||
|
// TODO: Don't lose the element type information.
|
||||||
|
return npcompListTypeGet(context);
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
std::stringstream message;
|
std::stringstream message;
|
||||||
message << "unable to map Torch type " << *torchType << " to MLIR type";
|
message << "unable to map Torch type " << *torchType << " to MLIR type";
|
||||||
|
|
|
@ -15,13 +15,16 @@ class TestModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.l = [1, 2]
|
self.l = [1, 2]
|
||||||
|
# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
|
||||||
|
# TODO: Don't lose element type.
|
||||||
|
# CHECK: torch.attr "l" : !basicpy.ListType
|
||||||
|
# CHECK: }
|
||||||
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
|
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
|
||||||
# CHECK: %[[N2:.*]] = basicpy.numeric_constant 2 : i64
|
# CHECK: %[[N2:.*]] = basicpy.numeric_constant 2 : i64
|
||||||
# CHECK: %[[LIST:.*]] = basicpy.build_list %[[N1]], %[[N2]] : (i64, i64) -> !basicpy.ListType
|
# CHECK: %[[LIST:.*]] = basicpy.build_list %[[N1]], %[[N2]] : (i64, i64) -> !basicpy.ListType
|
||||||
# CHECK: torch.nn_module {
|
# CHECK: torch.nn_module {
|
||||||
# CHECK: torch.attr "l", %[[LIST]] : !basicpy.ListType
|
# CHECK: torch.slot "l", %[[LIST]] : !basicpy.ListType
|
||||||
# CHECK: }
|
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
|
||||||
|
|
||||||
|
|
||||||
test_module = TestModule()
|
test_module = TestModule()
|
||||||
|
|
|
@ -32,8 +32,8 @@ class TestModule(torch.nn.Module):
|
||||||
# the case that the name is `__main__` Torch replaces it with `__torch__` to
|
# the case that the name is `__main__` Torch replaces it with `__torch__` to
|
||||||
# avoid collisions.
|
# avoid collisions.
|
||||||
|
|
||||||
# CHECK: func private @__npcomp_priv_fn.__torch__.Submodule.forward
|
# CHECK-DAG: func private @__npcomp_priv_fn.__torch__.TestModule.forward
|
||||||
# CHECK: func private @__npcomp_priv_fn.__torch__.TestModule.forward
|
# CHECK=DAG: func private @__npcomp_priv_fn.__torch__.Submodule.forward
|
||||||
|
|
||||||
|
|
||||||
test_module = TestModule()
|
test_module = TestModule()
|
||||||
|
|
|
@ -19,17 +19,22 @@ class TestModule(torch.nn.Module):
|
||||||
|
|
||||||
# The symbol name of the function is NOT load-bearing and cannot be relied upon.
|
# The symbol name of the function is NOT load-bearing and cannot be relied upon.
|
||||||
|
|
||||||
|
# CHECK-LABEL: torch.class_type
|
||||||
|
# CHECK-SAME: @[[CLASSTYPE:.*]] {
|
||||||
|
# CHECK: torch.method "forward", @[[SYMNAME:.*]]
|
||||||
|
# CHECK: }
|
||||||
|
|
||||||
|
|
||||||
# CHECK-LABEL: func private
|
# CHECK-LABEL: func private
|
||||||
# CHECK-SAME: @[[SYMNAME:.*]](
|
# CHECK-SAME: @[[SYMNAME]](
|
||||||
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module,
|
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"[[CLASSTYPE]]">,
|
||||||
# CHECK-SAME: %[[X:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
# CHECK-SAME: %[[X:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
||||||
# CHECK-SAME: %[[Y:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
# CHECK-SAME: %[[Y:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||||
# CHECK: %[[RET:.*]] = torch.kernel_call "aten::mul" %[[X]], %[[Y]]
|
# CHECK: %[[RET:.*]] = torch.kernel_call "aten::mul" %[[X]], %[[Y]]
|
||||||
# CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
|
# CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
|
||||||
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
||||||
# CHECK: torch.method "forward", @[[SYMNAME]]
|
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
|
||||||
# CHECK: }
|
|
||||||
|
|
||||||
|
|
||||||
test_module = TestModule()
|
test_module = TestModule()
|
||||||
|
|
|
@ -24,8 +24,8 @@ class TestModule(torch.nn.Module):
|
||||||
# CHECK: %[[L2:.*]] = basicpy.build_list
|
# CHECK: %[[L2:.*]] = basicpy.build_list
|
||||||
# CHECK: %[[L1:.*]] = basicpy.build_list
|
# CHECK: %[[L1:.*]] = basicpy.build_list
|
||||||
# CHECK: torch.nn_module {
|
# CHECK: torch.nn_module {
|
||||||
# CHECK: torch.attr "l2", %[[L2]]
|
# CHECK: torch.slot "l2", %[[L2]]
|
||||||
# CHECK: torch.attr "l1", %[[L1]]
|
# CHECK: torch.slot "l1", %[[L1]]
|
||||||
self.l2 = self.l1 = [1]
|
self.l2 = self.l1 = [1]
|
||||||
|
|
||||||
# This can be uncommented when the graph importer supports it.
|
# This can be uncommented when the graph importer supports it.
|
||||||
|
|
|
@ -16,8 +16,8 @@ class TestModule(torch.nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# CHECK: %[[A:.*]] = numpy.create_array_from_tensor
|
# CHECK: %[[A:.*]] = numpy.create_array_from_tensor
|
||||||
# CHECK: torch.nn_module {
|
# CHECK: torch.nn_module {
|
||||||
# CHECK: torch.attr "t1", %[[A]]
|
# CHECK: torch.slot "t1", %[[A]]
|
||||||
# CHECK: torch.attr "t2", %[[A]]
|
# CHECK: torch.slot "t2", %[[A]]
|
||||||
self.t1 = self.t2 = torch.tensor([10., 20.])
|
self.t1 = self.t2 = torch.tensor([10., 20.])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ class TestModule(torch.nn.Module):
|
||||||
self.t2 = torch.ones(1)
|
self.t2 = torch.ones(1)
|
||||||
|
|
||||||
# CHECK-LABEL: func{{.*}}TestModule.forward{{.*}}(
|
# CHECK-LABEL: func{{.*}}TestModule.forward{{.*}}(
|
||||||
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module) -> !basicpy.NoneType {
|
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">) -> !basicpy.NoneType {
|
||||||
def forward(self):
|
def forward(self):
|
||||||
# CHECK: %[[T2:.*]] = torch.prim.GetAttr %[[SELF]]["t2"]
|
# CHECK: %[[T2:.*]] = torch.prim.GetAttr %[[SELF]]["t2"]
|
||||||
# CHECK: torch.prim.SetAttr %[[SELF]]["t1"] = %[[T2]]
|
# CHECK: torch.prim.SetAttr %[[SELF]]["t1"] = %[[T2]]
|
||||||
|
@ -26,7 +26,7 @@ class TestModule(torch.nn.Module):
|
||||||
# CHECK: torch.prim.CallMethod %[[SELF]]["callee"] (%{{.*}}, %{{.*}})
|
# CHECK: torch.prim.CallMethod %[[SELF]]["callee"] (%{{.*}}, %{{.*}})
|
||||||
self.callee(self.t1, self.t2)
|
self.callee(self.t1, self.t2)
|
||||||
# CHECK-LABEL: func{{.*}}TestModule.callee{{.*}}(
|
# CHECK-LABEL: func{{.*}}TestModule.callee{{.*}}(
|
||||||
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module,
|
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">,
|
||||||
# CHECK-SAME: %[[X:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
# CHECK-SAME: %[[X:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
||||||
# CHECK-SAME: %[[Y:.*]]: !numpy.ndarray<*:!numpy.any_dtype>
|
# CHECK-SAME: %[[Y:.*]]: !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
def callee(self, x, y):
|
def callee(self, x, y):
|
||||||
|
|
|
@ -17,14 +17,20 @@ class TestModule(torch.nn.Module):
|
||||||
self.i = 3
|
self.i = 3
|
||||||
self.f = 42.5
|
self.f = 42.5
|
||||||
|
|
||||||
|
# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
|
||||||
|
# CHECK: torch.attr "training" : !basicpy.BoolType
|
||||||
|
# CHECK: torch.attr "i" : i64
|
||||||
|
# CHECK: torch.attr "f" : f64
|
||||||
|
# CHECK: }
|
||||||
# CHECK: %[[TRUE:.*]] = basicpy.bool_constant true
|
# CHECK: %[[TRUE:.*]] = basicpy.bool_constant true
|
||||||
# CHECK: %[[N3:.*]] = basicpy.numeric_constant 3 : i64
|
# CHECK: %[[N3:.*]] = basicpy.numeric_constant 3 : i64
|
||||||
# CHECK: %[[N42:.*]] = basicpy.numeric_constant 4.250000e+01 : f64
|
# CHECK: %[[N42:.*]] = basicpy.numeric_constant 4.250000e+01 : f64
|
||||||
# CHECK: %[[MODULE:.*]] = torch.nn_module {
|
# CHECK: %[[MODULE:.*]] = torch.nn_module {
|
||||||
# Note: for some reason, Torch always adds a "training" property to all modules.
|
# Note: for some reason, Torch always adds a "training" property to all modules.
|
||||||
# CHECK: torch.attr "training", %[[TRUE]] : !basicpy.BoolType
|
# CHECK: torch.slot "training", %[[TRUE]] : !basicpy.BoolType
|
||||||
# CHECK: torch.attr "i", %[[N3]] : i64
|
# CHECK: torch.slot "i", %[[N3]] : i64
|
||||||
# CHECK: torch.attr "f", %[[N42]] : f64
|
# CHECK: torch.slot "f", %[[N42]] : f64
|
||||||
|
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE:.*]]">
|
||||||
|
|
||||||
|
|
||||||
test_module = TestModule()
|
test_module = TestModule()
|
||||||
|
|
|
@ -15,6 +15,8 @@ class Submodule(torch.nn.Module):
|
||||||
def __init__(self, n):
|
def __init__(self, n):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n = n
|
self.n = n
|
||||||
|
def forward(self):
|
||||||
|
return self.n
|
||||||
|
|
||||||
class TestModule(torch.nn.Module):
|
class TestModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -27,9 +29,9 @@ class TestModule(torch.nn.Module):
|
||||||
# Modules with the same class can be selected between.
|
# Modules with the same class can be selected between.
|
||||||
# CHECK: %[[MOD:.*]] = scf.if
|
# CHECK: %[[MOD:.*]] = scf.if
|
||||||
s = self.s1 if b else self.s2
|
s = self.s1 if b else self.s2
|
||||||
# CHECK: %[[N:.*]] = torch.prim.GetAttr %5["n"]
|
# CHECK: %[[N:.*]] = torch.prim.CallMethod %[[MOD]]["forward"] ()
|
||||||
# CHECK: return %[[N]]
|
# CHECK: return %[[N]]
|
||||||
return s.n
|
return s.forward()
|
||||||
|
|
||||||
|
|
||||||
test_module = TestModule()
|
test_module = TestModule()
|
||||||
|
|
|
@ -26,20 +26,20 @@ class TestModule(torch.nn.Module):
|
||||||
|
|
||||||
# CHECK: %[[N0:.*]] = basicpy.numeric_constant 0 : i64
|
# CHECK: %[[N0:.*]] = basicpy.numeric_constant 0 : i64
|
||||||
# CHECK: %[[S0:.*]] = torch.nn_module {
|
# CHECK: %[[S0:.*]] = torch.nn_module {
|
||||||
# CHECK: torch.attr "training", %[[T]] : !basicpy.BoolType
|
# CHECK: torch.slot "training", %[[T]] : !basicpy.BoolType
|
||||||
# CHECK: torch.attr "n", %[[N0]] : i64
|
# CHECK: torch.slot "n", %[[N0]] : i64
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
|
|
||||||
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
|
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
|
||||||
# CHECK: %[[S1:.*]] = torch.nn_module {
|
# CHECK: %[[S1:.*]] = torch.nn_module {
|
||||||
# CHECK: torch.attr "training", %[[T]] : !basicpy.BoolType
|
# CHECK: torch.slot "training", %[[T]] : !basicpy.BoolType
|
||||||
# CHECK: torch.attr "n", %[[N1]] : i64
|
# CHECK: torch.slot "n", %[[N1]] : i64
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
|
|
||||||
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
||||||
# CHECK: torch.attr "training", %[[T]] : !basicpy.BoolType
|
# CHECK: torch.slot "training", %[[T]] : !basicpy.BoolType
|
||||||
# CHECK: torch.attr "s0", %[[S0]] : !torch.nn.Module
|
# CHECK: torch.slot "s0", %[[S0]] : !torch.nn.Module
|
||||||
# CHECK: torch.attr "s1", %[[S1]] : !torch.nn.Module
|
# CHECK: torch.slot "s1", %[[S1]] : !torch.nn.Module
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,8 +23,8 @@ class TestModule(torch.nn.Module):
|
||||||
# CHECK: %[[CT:.*]] = constant dense<1.000000e+00> : tensor<1xf32>
|
# CHECK: %[[CT:.*]] = constant dense<1.000000e+00> : tensor<1xf32>
|
||||||
# CHECK: %[[T:.*]] = numpy.create_array_from_tensor %[[CT]] : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
# CHECK: %[[T:.*]] = numpy.create_array_from_tensor %[[CT]] : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
||||||
# CHECK: torch.attr "p", %[[P]] : !numpy.ndarray<*:!numpy.any_dtype>
|
# CHECK: torch.slot "p", %[[P]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
# CHECK: torch.attr "t", %[[T]] : !numpy.ndarray<*:!numpy.any_dtype>
|
# CHECK: torch.slot "t", %[[T]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -124,8 +124,18 @@ MlirType npcompTupleTypeGet(MlirContext context);
|
||||||
/** Checks whether the given type is a torch.nn.Module type */
|
/** Checks whether the given type is a torch.nn.Module type */
|
||||||
int npcompTypeIsANnModule(MlirType t);
|
int npcompTypeIsANnModule(MlirType t);
|
||||||
|
|
||||||
/** Gets the singleton torch.nn.Module type. */
|
/** Gets the !torch.nn.Module type of the specified class. */
|
||||||
MlirType npcompNnModuleTypeGet(MlirContext context);
|
MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className);
|
||||||
|
|
||||||
|
/*============================================================================*/
|
||||||
|
/* torch.optional type. */
|
||||||
|
/*============================================================================*/
|
||||||
|
|
||||||
|
/** Checks whether the given type is a !torch.optional<T> type */
|
||||||
|
int npcompTypeIsAOptional(MlirType t);
|
||||||
|
|
||||||
|
/** Gets the !torch.optional<T> type with subtype T. */
|
||||||
|
MlirType npcompOptionalTypeGet(MlirType containedType);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
|
|
@ -1 +1,2 @@
|
||||||
add_subdirectory(IR)
|
add_subdirectory(IR)
|
||||||
|
add_subdirectory(Transforms)
|
||||||
|
|
|
@ -49,10 +49,11 @@ def Torch_KernelCallOp : Torch_Op<"kernel_call", [
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TorchScript modeling ops.
|
// TorchScript `torch.nn.Module` object instantiation ops.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
||||||
|
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
|
||||||
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::NnModuleTerminatorOp">]> {
|
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::NnModuleTerminatorOp">]> {
|
||||||
let summary = "Constructs a torch.nn.Module";
|
let summary = "Constructs a torch.nn.Module";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -65,15 +66,19 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```mlir
|
```mlir
|
||||||
%2 = torch.nn_module {
|
%2 = torch.nn_module {
|
||||||
torch.attr "b", %bool_true : !basicpy.BoolType
|
torch.slot "b", %bool_true : !basicpy.BoolType
|
||||||
torch.attr "i", %num3_i64 : i64
|
torch.slot "i", %num3_i64 : i64
|
||||||
torch.attr "f", %num : f64
|
torch.slot "f", %num : f64
|
||||||
torch.attr "t", %0 : !numpy.ndarray<*:!numpy.any_dtype>
|
torch.slot "t", %0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
torch.attr "submodule", %1 : !torch.nn.Module
|
torch.slot "submodule", %1 : !torch.nn.Module
|
||||||
torch.method "method", @f
|
} : !torch.nn.Module<"my_class_name">
|
||||||
}
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
This op is tightly coupled to the `torch.class_type` op named in the
|
||||||
|
`!torch.nn.Module<"my_class_name">` type. Each slot must match precisely
|
||||||
|
with the corresponding `torch.attr` in the `torch.class_type`.
|
||||||
|
See the documentation for `torch.class_type` for information.
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins);
|
let arguments = (ins);
|
||||||
|
@ -81,7 +86,11 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
||||||
let regions = (region SizedRegion<1>:$region);
|
let regions = (region SizedRegion<1>:$region);
|
||||||
let verifier = "return ::verify(*this);";
|
let verifier = "return ::verify(*this);";
|
||||||
|
|
||||||
let assemblyFormat = "$region attr-dict";
|
let assemblyFormat = "$region attr-dict `:` type($result)";
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
StringRef getClassName() { return getType().getClassName(); }
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
|
def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
|
||||||
|
@ -94,13 +103,13 @@ def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
|
||||||
let assemblyFormat = "attr-dict";
|
let assemblyFormat = "attr-dict";
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AttrOp : Torch_Op<"attr", [
|
def Torch_SlotOp : Torch_Op<"slot", [
|
||||||
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
|
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
|
||||||
let summary = "Define an attribute of a torch.nn.Module";
|
let summary = "Define the value of a slot of a torch.nn.Module";
|
||||||
let description = [{
|
let description = [{
|
||||||
This op declaratively specifies that the parent torch.nn_module has an
|
This op specifies that the initial value of the slot `name` of the
|
||||||
attribute `name` with value `value`, which is allowed to be an arbitrary
|
parent torch.nn_module should be `value`, which is allowed to be an
|
||||||
Torch-compatible SSA value, including other torch.nn.Module's.
|
arbitrary Torch-compatible SSA value, including other !torch.nn.Module's.
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins StrAttr:$name, AnyTorchType:$value);
|
let arguments = (ins StrAttr:$name, AnyTorchType:$value);
|
||||||
|
@ -111,13 +120,73 @@ def Torch_AttrOp : Torch_Op<"attr", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Modeling of TorchScript class types
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def Torch_ClassTypeOp : Torch_Op<"class_type", [
|
||||||
|
Symbol,
|
||||||
|
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::ClassTypeTerminatorOp">]> {
|
||||||
|
let summary = "Constructs a torch.ClassType";
|
||||||
|
let description = [{
|
||||||
|
Declares a class type. Class types are the types used to describe
|
||||||
|
TorchScript `torch.nn.Module`'s. The terminology "class type" is for
|
||||||
|
consistency with TorchScript (a better name in our context might be
|
||||||
|
"nn module subtype"). The `syn_name` of this op is the same string
|
||||||
|
as in the `!torch.nn.Module<"...">` type.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
// A simple empty torch.class_type, with corresponding torch.nn_module.
|
||||||
|
torch.class_type @empty {}
|
||||||
|
%submodule = torch.nn_module {} : !torch.nn.Module<"empty">
|
||||||
|
|
||||||
|
// A class type with many members.
|
||||||
|
torch.class_type @test {
|
||||||
|
torch.attr "b" : !basicpy.BoolType
|
||||||
|
torch.attr "i" : i64
|
||||||
|
torch.attr "f" : f64
|
||||||
|
torch.attr "t" : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
torch.attr "submodule" : !torch.nn.Module<"empty">
|
||||||
|
torch.method "method", @f
|
||||||
|
}
|
||||||
|
torch.nn_module {
|
||||||
|
// These must match the order and names in the `torch.class_type`.
|
||||||
|
torch.slot "b", %bool_true : !basicpy.BoolType
|
||||||
|
torch.slot "i", %num3_i64 : i64
|
||||||
|
torch.slot "f", %num : f64
|
||||||
|
torch.slot "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
torch.slot "submodule", %submodule : !torch.nn.Module<"empty">
|
||||||
|
} : !torch.nn.Module<"test">
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins SymbolNameAttr:$sym_name);
|
||||||
|
let results = (outs);
|
||||||
|
let regions = (region SizedRegion<1>:$region);
|
||||||
|
let verifier = "return ::verify(*this);";
|
||||||
|
|
||||||
|
let assemblyFormat = "$sym_name $region attr-dict";
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator,
|
||||||
|
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">]> {
|
||||||
|
let summary = "Implicit terminator for torch.class_type";
|
||||||
|
|
||||||
|
let arguments = (ins);
|
||||||
|
let results = (outs);
|
||||||
|
|
||||||
|
let assemblyFormat = "attr-dict";
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_MethodOp : Torch_Op<"method", [
|
def Torch_MethodOp : Torch_Op<"method", [
|
||||||
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">,
|
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">,
|
||||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>
|
DeclareOpInterfaceMethods<SymbolUserOpInterface>
|
||||||
]> {
|
]> {
|
||||||
let summary = "Define a method of a torch.nn.Module";
|
let summary = "Declare a method of a torch.class_type";
|
||||||
let description = [{
|
let description = [{
|
||||||
This op declaratively specifies that the parent torch.nn_module has a
|
This op declaratively specifies that the parent torch.class_type has a
|
||||||
method `name` which calls `function`. `function` is an unbound function.
|
method `name` which calls `function`. `function` is an unbound function.
|
||||||
That is, it explicitly takes the torch.nn.Module as a parameter (no implicit
|
That is, it explicitly takes the torch.nn.Module as a parameter (no implicit
|
||||||
"self" object).
|
"self" object).
|
||||||
|
@ -131,6 +200,88 @@ def Torch_MethodOp : Torch_Op<"method", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AttrOp : Torch_Op<"attr", [
|
||||||
|
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">
|
||||||
|
]> {
|
||||||
|
let summary = "Declare an attribute of a torch.class_type";
|
||||||
|
let description = [{
|
||||||
|
This op declaratively specifies that torch.nn.Module's of the parent
|
||||||
|
torch.class_type must have an attribute `name` of type `type`.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins StrAttr:$name, TypeAttr:$type);
|
||||||
|
let results = (outs);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$name `:` $type attr-dict
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Global slot ops
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TODO: Should these be in a separate dialect?
|
||||||
|
// At this point, they are fairly specific to torch types, but their get/set
|
||||||
|
// semantics follow Python.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
|
||||||
|
Symbol,
|
||||||
|
IsolatedFromAbove,
|
||||||
|
]> {
|
||||||
|
let summary = "A slot with global storage";
|
||||||
|
let description = [{
|
||||||
|
Represents a slot with global storage. The slot semantics are the same
|
||||||
|
as Python's: getting or setting a slot is done by object identity.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$typeBound);
|
||||||
|
let results = (outs);
|
||||||
|
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$sym_name attr-dict `:` $typeBound
|
||||||
|
}];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// The name of the function, which, for semantic correctness, must be called
|
||||||
|
// exactly once and this call must be done before any other calls into
|
||||||
|
// the module.
|
||||||
|
// TODO: Avoid load-bearing names.
|
||||||
|
// We could replace this with an op that marks the function as initializer.
|
||||||
|
static constexpr StringRef getGlobalSlotInitializerFuncName() {
|
||||||
|
return "__torch_global_slot_initializer";
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_GlobalSlotGetOp : Torch_Op<"global_slot.get", []> {
|
||||||
|
let summary = "Get the value stored in a torch.global_slot";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
FlatSymbolRefAttr:$slot
|
||||||
|
);
|
||||||
|
let results = (outs AnyTorchType:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$slot attr-dict `:` type($result)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_GlobalSlotSetOp : Torch_Op<"global_slot.set", []> {
|
||||||
|
let summary = "Set the value stored in a torch.global_slot";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
FlatSymbolRefAttr:$slot,
|
||||||
|
AnyTorchType:$value
|
||||||
|
);
|
||||||
|
let results = (outs);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$slot `=` $value attr-dict `:` type($value)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TorchScript `prim::` ops.
|
// TorchScript `prim::` ops.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -142,7 +293,7 @@ def Torch_PrimGetAttrOp : Torch_Op<"prim.GetAttr", []> {
|
||||||
let results = (outs AnyTorchType:$result);
|
let results = (outs AnyTorchType:$result);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$receiver `[` $name `]` attr-dict `:` type($result)
|
$receiver `[` $name `]` attr-dict `:` type($receiver) `->` type($result)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -157,7 +308,7 @@ def Torch_PrimSetAttrOp : Torch_Op<"prim.SetAttr", []> {
|
||||||
let results = (outs);
|
let results = (outs);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$receiver `[` $name `]` `=` $value attr-dict `:` type($value)
|
$receiver `[` $name `]` `=` $value attr-dict `:` type($receiver) `,` type($value)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,7 +323,7 @@ def Torch_PrimCallMethodOp : Torch_Op<"prim.CallMethod", []> {
|
||||||
let results = (outs AnyTorchType:$result);
|
let results = (outs AnyTorchType:$result);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$receiver `[` $name `]` `(` $operands `)` attr-dict `:` type($operands) `->` type($result)
|
$receiver `[` $name `]` `(` $operands `)` attr-dict `:` type($receiver) `,` functional-type($operands, $result)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,9 +22,62 @@ class Torch_Type<string name, string typeMnemonic> : TypeDef<Torch_Dialect, name
|
||||||
def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
|
def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
|
||||||
let summary = "torch.nn.Module";
|
let summary = "torch.nn.Module";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
Represents an instance of a `torch.nn.Module` with the given `className`.
|
||||||
|
}];
|
||||||
|
let parameters = (ins StringRefParameter<"class name">:$className);
|
||||||
|
|
||||||
|
let printer = [{
|
||||||
|
$_printer << "nn.Module<\"";
|
||||||
|
llvm::printEscapedString(getImpl()->className, $_printer.getStream());
|
||||||
|
$_printer << "\">";
|
||||||
|
}];
|
||||||
|
|
||||||
|
let parser = [{
|
||||||
|
if (parser.parseLess())
|
||||||
|
return Type();
|
||||||
|
StringRef className;
|
||||||
|
if ($_parser.parseOptionalString(&className))
|
||||||
|
return Type();
|
||||||
|
if ($_parser.parseGreater())
|
||||||
|
return Type();
|
||||||
|
return get($_ctxt, className);
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: It feels like this should be something more general.
|
||||||
|
// However, to do that, we need to agree on construction operations
|
||||||
|
// and the valid MLIR representations of the "None" state.
|
||||||
|
//
|
||||||
|
// For now, we only need it as a stand-in type to allow importing
|
||||||
|
// the `_is_full_backward_hook` optional bool type that Torch puts on
|
||||||
|
// all classes.
|
||||||
|
def Torch_OptionalType : Torch_Type<"Optional", "optional"> {
|
||||||
|
let summary = "!torch.optional<T>";
|
||||||
|
let description = [{
|
||||||
|
}];
|
||||||
|
let parameters = (ins "::mlir::Type":$containedType);
|
||||||
|
|
||||||
|
let printer = [{
|
||||||
|
$_printer << "optional<" << getImpl()->containedType << ">";
|
||||||
|
}];
|
||||||
|
|
||||||
|
let parser = [{
|
||||||
|
if (parser.parseLess())
|
||||||
|
return Type();
|
||||||
|
Type containedType;
|
||||||
|
if ($_parser.parseType(containedType))
|
||||||
|
return Type();
|
||||||
|
if ($_parser.parseGreater())
|
||||||
|
return Type();
|
||||||
|
return get($_ctxt, containedType);
|
||||||
|
}];
|
||||||
|
|
||||||
|
let builders = [
|
||||||
|
TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{
|
||||||
|
return Base::get(containedType.getContext(), containedType);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Type predicates
|
// Type predicates
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||||
|
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||||
|
add_public_tablegen_target(NPCOMPTorchPassIncGen)
|
||||||
|
|
||||||
|
add_mlir_doc(Passes -gen-pass-doc NPCOMPTorchTransforms ./)
|
|
@ -0,0 +1,30 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
||||||
|
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace NPCOMP {
|
||||||
|
namespace Torch {
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
|
||||||
|
|
||||||
|
} // namespace Torch
|
||||||
|
|
||||||
|
/// Registers all Torch transformation passes.
|
||||||
|
void registerTorchPasses();
|
||||||
|
|
||||||
|
} // namespace NPCOMP
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
|
@ -0,0 +1,65 @@
|
||||||
|
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_TORCH_PASSES
|
||||||
|
#define NPCOMP_TORCH_PASSES
|
||||||
|
|
||||||
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
|
def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
|
||||||
|
let summary = "Converts TorchScript object graphs to a globalized form";
|
||||||
|
let constructor = "mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass()";
|
||||||
|
let description = [{
|
||||||
|
This pass converts a subset of possible TorchScript modules into a
|
||||||
|
more restrictive lower-level form that strips away the need to be
|
||||||
|
concerned with instances of !torch.nn.Module<...> type. Specifically,
|
||||||
|
the object graph is flattened into a set of discrete globals
|
||||||
|
(`torch.global_slot`) that hold the program state.
|
||||||
|
|
||||||
|
The overarching goal is for a strict correspondence between the original
|
||||||
|
`torch.nn.Module` (call it `root`) that the user `torch.jit.script`'ed, and
|
||||||
|
the public interface of the resulting MLIR module. Specifically:
|
||||||
|
- The call `root.encoder.forward(...)` in Python corresponds to invoking
|
||||||
|
the `func @encoder.forward` on the resulting MLIR module.
|
||||||
|
- The data member access `root.decoder.ids_to_strings_table` in Python
|
||||||
|
corresponds to accessing the
|
||||||
|
`torch.global_slot @decoder.ids_to_strings_table` on the resulting
|
||||||
|
MLIR module.
|
||||||
|
In effect, the entire MLIR module corresponds to an instance of the `root`
|
||||||
|
object. This matches with the intuitive behavior desired for deployment:
|
||||||
|
When the MLIR module (or, more likely, a compiled artifact derived from it)
|
||||||
|
is loaded in a deployed environment, it is equivalent to recreating the
|
||||||
|
original `root` object.
|
||||||
|
|
||||||
|
This pass performs a complete change of the externally visible calling
|
||||||
|
convention of the MLIR module for a graph of objects and methods to a
|
||||||
|
fixed set of globals and functions.
|
||||||
|
|
||||||
|
Of course, only a subset of programs can be transformed, and this pass fails
|
||||||
|
with an error if the conditions are violated.
|
||||||
|
|
||||||
|
Specifically, the restrictions are:
|
||||||
|
- There must be a unique torch.nn_module that is not the value of a slot
|
||||||
|
of any other torch.nn_module
|
||||||
|
- Rationale: Allows us to have a notion of a unique "root" op, which is
|
||||||
|
used to define linkage. This also matches how TorchScript imports in
|
||||||
|
practice (`torch.jit.script` imports a single root object).
|
||||||
|
- There must be exactly one instance of each torch.class_type. Equivalently,
|
||||||
|
Every torch.nn_module must have a distinct type.
|
||||||
|
- Rationale: This guarantee precludes things like selecting between
|
||||||
|
multiple modules dynamically at runtime, which would require indirecting
|
||||||
|
between the separate storage of each instance.
|
||||||
|
- All torch.nn_module's must be reachable by a unique path from the root
|
||||||
|
- Rationale: Eliminates possibility of potentially exponential number of
|
||||||
|
paths. Or worse, infinite number of paths when considering cyclic
|
||||||
|
object graphs. Also as of Feb 2021, TorchScript won't import into
|
||||||
|
this form (it has a bug related to the identity of submodules).
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // NPCOMP_TORCH_PASSES
|
|
@ -156,7 +156,21 @@ int npcompTypeIsANnModule(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::NnModuleType>();
|
return unwrap(t).isa<Torch::NnModuleType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Gets the singleton torch.nn.Module type. */
|
/** Gets the torch.nn.Module type of the specified class. */
|
||||||
MlirType npcompNnModuleTypeGet(MlirContext context) {
|
MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className) {
|
||||||
return wrap(Torch::NnModuleType::get(unwrap(context)));
|
return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/*============================================================================*/
|
||||||
|
/* torch.optional type. */
|
||||||
|
/*============================================================================*/
|
||||||
|
|
||||||
|
/** Checks whether the given type is a !torch.optional<T> type */
|
||||||
|
int npcompTypeIsAOptional(MlirType t) {
|
||||||
|
return unwrap(t).isa<Torch::OptionalType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Gets the !torch.optional<T> type with subtype T. */
|
||||||
|
MlirType npcompOptionalTypeGet(MlirType containedType) {
|
||||||
|
return wrap(Torch::OptionalType::get(unwrap(containedType)));
|
||||||
}
|
}
|
||||||
|
|
|
@ -1 +1,2 @@
|
||||||
add_subdirectory(IR)
|
add_subdirectory(IR)
|
||||||
|
add_subdirectory(Transforms)
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||||
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
||||||
|
#include "llvm/ADT/StringMap.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
|
@ -49,8 +50,23 @@ KernelMetadata KernelCallOp::getTorchKernelMetadata() {
|
||||||
LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||||
auto func = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, function());
|
auto func = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, function());
|
||||||
if (!func)
|
if (!func)
|
||||||
return emitError() << "'" << function()
|
return emitError() << "'@" << function()
|
||||||
<< "' does not reference a valid function";
|
<< "' does not reference a valid function";
|
||||||
|
if (func.getVisibility() != SymbolTable::Visibility::Private)
|
||||||
|
return emitError() << "'@" << function()
|
||||||
|
<< "' must reference a private function";
|
||||||
|
if (func.isDeclaration())
|
||||||
|
return emitError() << "'@" << function()
|
||||||
|
<< "' must reference a function that is defined (not "
|
||||||
|
"merely declared)";
|
||||||
|
auto expectedReceiverArgType = NnModuleType::get(
|
||||||
|
getContext(), getOperation()->getParentOfType<ClassTypeOp>().getName());
|
||||||
|
if (func.getType().getNumInputs() == 0 ||
|
||||||
|
func.getType().getInput(0) != expectedReceiverArgType) {
|
||||||
|
return emitError() << "the referenced function '" << function()
|
||||||
|
<< "' must have a first argument of type "
|
||||||
|
<< expectedReceiverArgType;
|
||||||
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -60,8 +76,82 @@ LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||||
|
|
||||||
static LogicalResult verify(NnModuleOp op) {
|
static LogicalResult verify(NnModuleOp op) {
|
||||||
for (Operation &child : *op.getBody())
|
for (Operation &child : *op.getBody())
|
||||||
if (!isa<AttrOp, MethodOp, NnModuleTerminatorOp>(&child))
|
if (!isa<SlotOp, NnModuleTerminatorOp>(&child))
|
||||||
return child.emitOpError() << "is not allowed inside `torch.nn_module`";
|
return child.emitOpError() << "is not allowed inside 'torch.nn_module'";
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// PyTorch has a well-developed notion of subtyping.
|
||||||
|
//
|
||||||
|
// This is a restricted subset of it.
|
||||||
|
//
|
||||||
|
// TODO: Flesh this out.
|
||||||
|
bool isValidSubtype(Type subtype, Type type) {
|
||||||
|
if (subtype == type)
|
||||||
|
return true;
|
||||||
|
if (auto optional = type.dyn_cast<OptionalType>())
|
||||||
|
return subtype == optional.getContainedType() ||
|
||||||
|
subtype.isa<Basicpy::NoneType>();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||||
|
auto classType =
|
||||||
|
symbolTable.lookupNearestSymbolFrom<ClassTypeOp>(*this, getClassName());
|
||||||
|
if (!classType)
|
||||||
|
return emitError() << "'" << getClassName()
|
||||||
|
<< "' does not reference a valid class type";
|
||||||
|
|
||||||
|
auto attrs = llvm::to_vector<6>(getBody()->getOps<SlotOp>());
|
||||||
|
auto attrDefs = llvm::to_vector<6>(classType.getBody()->getOps<AttrOp>());
|
||||||
|
if (attrs.size() != attrDefs.size())
|
||||||
|
return emitError() << "number of 'torch.slot's in a 'torch.nn_module' must "
|
||||||
|
"match number of 'torch.attr's in "
|
||||||
|
"the corresponding 'torch.class_type'";
|
||||||
|
for (int i = 0, e = attrs.size(); i != e; i++) {
|
||||||
|
SlotOp attr = attrs[i];
|
||||||
|
AttrOp attrDef = attrDefs[i];
|
||||||
|
if (!isValidSubtype(attr.value().getType(), attrDef.type()) ||
|
||||||
|
attr.name() != attrDef.name()) {
|
||||||
|
return attr.emitOpError()
|
||||||
|
.append("is expected to match type and name of '",
|
||||||
|
attrDef.getOperation(), "'")
|
||||||
|
.attachNote(attrDef.getLoc())
|
||||||
|
.append("see torch.attr at corresponding index ", i, " here");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ClassTypeOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static LogicalResult verify(ClassTypeOp op) {
|
||||||
|
llvm::StringMap<Operation *> namesToOps;
|
||||||
|
for (Operation &child : op.getBody()->without_terminator()) {
|
||||||
|
if (!isa<AttrOp, MethodOp>(&child))
|
||||||
|
return child.emitOpError() << "is not allowed inside `torch.class_type`";
|
||||||
|
StringRef name;
|
||||||
|
if (auto attr = dyn_cast<AttrOp>(child))
|
||||||
|
name = attr.name();
|
||||||
|
else
|
||||||
|
name = cast<MethodOp>(child).name();
|
||||||
|
auto itAndWasInserted = namesToOps.insert({name, &child});
|
||||||
|
auto it = itAndWasInserted.first;
|
||||||
|
bool wasInserted = itAndWasInserted.second;
|
||||||
|
if (!wasInserted) {
|
||||||
|
auto diag = op.emitOpError().append(
|
||||||
|
"has duplicate attr/method with name '", name, "'");
|
||||||
|
diag.attachNote(it->second->getLoc())
|
||||||
|
.append("see first conflicting attr/method here");
|
||||||
|
diag.attachNote(child.getLoc())
|
||||||
|
.append("see second conflicting attr/method here");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
add_npcomp_conversion_library(NPCOMPTorchPasses
|
||||||
|
Passes.cpp
|
||||||
|
GlobalizeObjectGraph.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
NPCOMPTorchPassIncGen
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRPass
|
||||||
|
NPCOMPTorchDialect
|
||||||
|
NPCOMPBasicpyDialect
|
||||||
|
)
|
|
@ -0,0 +1,340 @@
|
||||||
|
//===- GlobalizeObjectGraph.cpp ----------------------------------*- C++-*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "PassDetail.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
#include "llvm/ADT/MapVector.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
using namespace mlir::NPCOMP::Torch;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// See the pass documentation for `torch-globalize-object-graph`.
|
||||||
|
class ObjectGraphGlobalizer {
|
||||||
|
public:
|
||||||
|
ObjectGraphGlobalizer(ModuleOp module);
|
||||||
|
LogicalResult globalizeObjectGraph();
|
||||||
|
|
||||||
|
private:
|
||||||
|
FailureOr<NnModuleOp> findRootNnModule();
|
||||||
|
LogicalResult checkSingleInstanceOfEachClass();
|
||||||
|
LogicalResult recursivelyTraverseClassType(ClassTypeOp classType);
|
||||||
|
void createInitializerFunc();
|
||||||
|
LogicalResult rewriteMethods();
|
||||||
|
void removeObjectGraph();
|
||||||
|
|
||||||
|
ModuleOp module;
|
||||||
|
SymbolTable symbolTable;
|
||||||
|
OpBuilder globalBuilder;
|
||||||
|
// The stack of attribute names we have traversed during our recursive
|
||||||
|
// traversal of the class/object hierarchy.
|
||||||
|
//
|
||||||
|
// Linkage names are calculated based on the set of attribute names traversed
|
||||||
|
// from the root class/module in the program.
|
||||||
|
SmallVector<std::string> nameStack;
|
||||||
|
|
||||||
|
// Sometimes it is natural to want a map keyed on torch.attr ops or torch.slot
|
||||||
|
// ops. However, usually it is better to keep a map keyed on an ClassTypeOp
|
||||||
|
// + attr name since frequently that is all one has access to and it
|
||||||
|
// would be tedious to scan the body of the ClassTypeOp for the torch.attr
|
||||||
|
// with the corresponding name.
|
||||||
|
using AttrOfClass =
|
||||||
|
std::pair</*ClassTypeOp*/ Operation *, /*attr name*/ StringRef>;
|
||||||
|
// The initial value associated with an attribute of a class.
|
||||||
|
// Since we only allow a single instance of a class, this is equivalent to
|
||||||
|
// the initial value of the unique slot corresponding to that attr.
|
||||||
|
DenseMap<AttrOfClass, Value> slotInitialValues;
|
||||||
|
// The inverse map of `slotInitialValues`.
|
||||||
|
// Many attributes can have the same initial value, so the value type
|
||||||
|
// is a vector.
|
||||||
|
DenseMap<Value, std::vector<AttrOfClass>> slotInitialValuesInverseMap;
|
||||||
|
|
||||||
|
// The torch.global_slot corresponding to each torch.attr/torch.slot.
|
||||||
|
DenseMap<AttrOfClass, GlobalSlotOp> globalSlotForAttr;
|
||||||
|
// The linkage name (value) for the function with symbol name (key).
|
||||||
|
DenseMap<StringRef, std::string> methodLinkageNames;
|
||||||
|
|
||||||
|
// The set of class types that have already been processed.
|
||||||
|
// Used for diagnostics.
|
||||||
|
// The map value is the original path from the root that we found it at.
|
||||||
|
DenseMap</*ClassTypeOp*/ Operation *, std::string> seenClassTypes;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
ObjectGraphGlobalizer::ObjectGraphGlobalizer(ModuleOp module)
|
||||||
|
: module(module), symbolTable(module),
|
||||||
|
globalBuilder(module.getBodyRegion()) {}
|
||||||
|
|
||||||
|
LogicalResult ObjectGraphGlobalizer::globalizeObjectGraph() {
|
||||||
|
// We require there to be a unique root !torch.nn.Module.
|
||||||
|
FailureOr<NnModuleOp> maybeRootNnModule = findRootNnModule();
|
||||||
|
if (failed(maybeRootNnModule))
|
||||||
|
return failure();
|
||||||
|
NnModuleOp rootNnModule = *maybeRootNnModule;
|
||||||
|
if (!rootNnModule)
|
||||||
|
return module.emitError()
|
||||||
|
<< "module does not contain a root torch.nn_module";
|
||||||
|
|
||||||
|
// We require one instance of each class. That is, there is a single
|
||||||
|
// torch.nn_module for each torch.class_type.
|
||||||
|
if (failed(checkSingleInstanceOfEachClass()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
for (NnModuleOp nnModule : module.getOps<NnModuleOp>()) {
|
||||||
|
auto classType = symbolTable.lookup<ClassTypeOp>(nnModule.getClassName());
|
||||||
|
for (auto slot : nnModule.getOps<SlotOp>()) {
|
||||||
|
AttrOfClass attrOfClass = {classType, slot.name()};
|
||||||
|
slotInitialValues[attrOfClass] = slot.value();
|
||||||
|
slotInitialValuesInverseMap[slot.value()].push_back(attrOfClass);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively traverse the class hierarchy, globalizing slots and
|
||||||
|
// tracking linkage names for methods.
|
||||||
|
auto rootClassType =
|
||||||
|
symbolTable.lookup<ClassTypeOp>(rootNnModule.getClassName());
|
||||||
|
if (failed(recursivelyTraverseClassType(rootClassType)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Move all slot initial values into an initializer func.
|
||||||
|
createInitializerFunc();
|
||||||
|
|
||||||
|
// Rewrite torch.prim.GetAttr/torch.prim.SetAttr/torch.prim.CallMethod.
|
||||||
|
if (failed(rewriteMethods()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Now that all we have finished converting to the new form, remove
|
||||||
|
// the original object graph.
|
||||||
|
removeObjectGraph();
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<NnModuleOp> ObjectGraphGlobalizer::findRootNnModule() {
|
||||||
|
NnModuleOp rootNnModule;
|
||||||
|
for (NnModuleOp op : module.getOps<NnModuleOp>()) {
|
||||||
|
if (!op.use_empty())
|
||||||
|
continue;
|
||||||
|
if (rootNnModule) {
|
||||||
|
op.emitError()
|
||||||
|
.append("found more than one root module (module that is not a "
|
||||||
|
"child of any other module)")
|
||||||
|
.attachNote(rootNnModule.getLoc())
|
||||||
|
.append("see other root module here");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
rootNnModule = op;
|
||||||
|
}
|
||||||
|
return rootNnModule;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult ObjectGraphGlobalizer::checkSingleInstanceOfEachClass() {
|
||||||
|
llvm::MapVector</*ClassTypeOp*/ Operation *, std::vector<NnModuleOp>>
|
||||||
|
classInstances;
|
||||||
|
for (NnModuleOp op : module.getOps<NnModuleOp>()) {
|
||||||
|
auto classType = symbolTable.lookup<ClassTypeOp>(op.getClassName());
|
||||||
|
classInstances[classType].push_back(op);
|
||||||
|
}
|
||||||
|
for (auto &p : classInstances) {
|
||||||
|
ClassTypeOp classType = cast<ClassTypeOp>(p.first);
|
||||||
|
ArrayRef<NnModuleOp> instances = p.second;
|
||||||
|
if (instances.size() > 1) {
|
||||||
|
// TODO: Improve this diagnostic based on user use cases.
|
||||||
|
// This is a user-facing diagnostic that enforces key invariants to
|
||||||
|
// our TorchScript subset.
|
||||||
|
auto diag = classType.emitError(
|
||||||
|
"class type has more than one instance: the current TorchScript "
|
||||||
|
"supported subset only allows single instances");
|
||||||
|
for (NnModuleOp instance : instances) {
|
||||||
|
diag.attachNote(instance.getLoc()) << "see instance here";
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
ObjectGraphGlobalizer::recursivelyTraverseClassType(ClassTypeOp classType) {
|
||||||
|
std::string pathToClassFromRoot = llvm::join(nameStack, ".");
|
||||||
|
if (!seenClassTypes.insert({classType, pathToClassFromRoot}).second) {
|
||||||
|
return classType.emitError()
|
||||||
|
<< "reachable by multiple paths from root object: '<root>."
|
||||||
|
<< seenClassTypes[classType] << "' and '<root>."
|
||||||
|
<< pathToClassFromRoot << "'";
|
||||||
|
}
|
||||||
|
|
||||||
|
// For each attr, create a global slot for it.
|
||||||
|
for (auto attr : classType.getOps<AttrOp>()) {
|
||||||
|
nameStack.push_back(attr.name().str());
|
||||||
|
if (auto type = attr.type().dyn_cast<NnModuleType>()) {
|
||||||
|
recursivelyTraverseClassType(
|
||||||
|
symbolTable.lookup<ClassTypeOp>(type.getClassName()));
|
||||||
|
} else {
|
||||||
|
auto linkageName = llvm::join(nameStack, ".");
|
||||||
|
auto globalSlot = globalBuilder.create<GlobalSlotOp>(
|
||||||
|
attr->getLoc(), linkageName, TypeAttr::get(attr.type()));
|
||||||
|
AttrOfClass attrOfClass = {classType, attr.name()};
|
||||||
|
assert(globalSlotForAttr.find(attrOfClass) == globalSlotForAttr.end());
|
||||||
|
globalSlotForAttr[attrOfClass] = globalSlot;
|
||||||
|
}
|
||||||
|
nameStack.pop_back();
|
||||||
|
}
|
||||||
|
// For each method, track the linkage name it will eventually have.
|
||||||
|
for (auto method : classType.getOps<MethodOp>()) {
|
||||||
|
nameStack.push_back(method.name().str());
|
||||||
|
auto linkageName = llvm::join(nameStack, ".");
|
||||||
|
nameStack.pop_back();
|
||||||
|
if (!methodLinkageNames.insert({method.function(), linkageName}).second)
|
||||||
|
method.emitError() << "unbound function shared by multiple methods";
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void ObjectGraphGlobalizer::createInitializerFunc() {
|
||||||
|
auto loc = module.getLoc();
|
||||||
|
auto func = globalBuilder.create<FuncOp>(
|
||||||
|
loc, GlobalSlotOp::getGlobalSlotInitializerFuncName(),
|
||||||
|
globalBuilder.getFunctionType({}, {}));
|
||||||
|
OpBuilder builder(func.getContext());
|
||||||
|
Block *body = builder.createBlock(&func.getBody());
|
||||||
|
|
||||||
|
SmallVector<Operation *> opsToMove;
|
||||||
|
for (Operation &op : llvm::make_early_inc_range(*module.getBody())) {
|
||||||
|
if (isa<ClassTypeOp, NnModuleOp, GlobalSlotOp, FuncOp, ModuleTerminatorOp>(
|
||||||
|
&op))
|
||||||
|
continue;
|
||||||
|
op.moveBefore(body, body->end());
|
||||||
|
for (Value result : llvm::make_early_inc_range(op.getResults())) {
|
||||||
|
auto it = slotInitialValuesInverseMap.find(result);
|
||||||
|
if (it == slotInitialValuesInverseMap.end())
|
||||||
|
continue;
|
||||||
|
for (AttrOfClass attrOfClass : it->second) {
|
||||||
|
GlobalSlotOp globalSlot = globalSlotForAttr[attrOfClass];
|
||||||
|
OpBuilder::atBlockEnd(body).create<GlobalSlotSetOp>(
|
||||||
|
globalSlot.getLoc(), globalSlot.sym_name(), result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.create<ReturnOp>(loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult ObjectGraphGlobalizer::rewriteMethods() {
|
||||||
|
DenseMap<AttrOfClass, StringRef> linkageNames;
|
||||||
|
for (auto classType : module.getOps<ClassTypeOp>()) {
|
||||||
|
for (auto method : classType.getOps<MethodOp>()) {
|
||||||
|
auto it = methodLinkageNames.find(method.function());
|
||||||
|
if (it == methodLinkageNames.end())
|
||||||
|
continue;
|
||||||
|
linkageNames[{classType, method.name()}] = it->second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// We only handle a small subset of ops that conform with the set of
|
||||||
|
// assumptions that allow us to globalize the object graph. Anything that
|
||||||
|
// tries to treat modules as bona-fide objects and not just namespaces
|
||||||
|
// of methods with a single instance of the corresponding type just gets
|
||||||
|
// arbitrarily tricky to rewrite. E.g. what if the user creates a list
|
||||||
|
// of modules, or there is an scf.if selecting between modules, etc.
|
||||||
|
auto rewriteOpWithNnModuleTypeOperand = [&](Operation *op) {
|
||||||
|
if (auto primSetAttr = dyn_cast<PrimSetAttrOp>(op)) {
|
||||||
|
auto classType = symbolTable.lookup<ClassTypeOp>(
|
||||||
|
primSetAttr.receiver().getType().cast<NnModuleType>().getClassName());
|
||||||
|
auto globalSlot = globalSlotForAttr[{classType, primSetAttr.name()}];
|
||||||
|
OpBuilder(primSetAttr)
|
||||||
|
.create<GlobalSlotSetOp>(primSetAttr.getLoc(), globalSlot.sym_name(),
|
||||||
|
primSetAttr.value());
|
||||||
|
primSetAttr.erase();
|
||||||
|
}
|
||||||
|
if (auto primGetAttr = dyn_cast<PrimGetAttrOp>(op)) {
|
||||||
|
auto classType = symbolTable.lookup<ClassTypeOp>(
|
||||||
|
primGetAttr.receiver().getType().cast<NnModuleType>().getClassName());
|
||||||
|
auto globalSlot = globalSlotForAttr[{classType, primGetAttr.name()}];
|
||||||
|
auto globalSlotGet = OpBuilder(primGetAttr)
|
||||||
|
.create<GlobalSlotGetOp>(primGetAttr.getLoc(),
|
||||||
|
primGetAttr.getType(),
|
||||||
|
globalSlot.sym_name());
|
||||||
|
primGetAttr.replaceAllUsesWith(globalSlotGet.getOperation());
|
||||||
|
primGetAttr.erase();
|
||||||
|
}
|
||||||
|
if (auto primCallMethod = dyn_cast<PrimCallMethodOp>(op)) {
|
||||||
|
auto classType = symbolTable.lookup<ClassTypeOp>(primCallMethod.receiver()
|
||||||
|
.getType()
|
||||||
|
.cast<NnModuleType>()
|
||||||
|
.getClassName());
|
||||||
|
StringRef linkageName = linkageNames[{classType, primCallMethod.name()}];
|
||||||
|
auto call = OpBuilder(primCallMethod)
|
||||||
|
.create<CallOp>(primCallMethod.getLoc(), linkageName,
|
||||||
|
primCallMethod.getType(),
|
||||||
|
primCallMethod.operands());
|
||||||
|
primCallMethod.replaceAllUsesWith(call);
|
||||||
|
primCallMethod.erase();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
for (auto classType : module.getOps<ClassTypeOp>()) {
|
||||||
|
for (auto method : classType.getOps<MethodOp>()) {
|
||||||
|
auto it = methodLinkageNames.find(method.function());
|
||||||
|
if (it == methodLinkageNames.end())
|
||||||
|
continue;
|
||||||
|
FuncOp func = symbolTable.lookup<FuncOp>(method.function());
|
||||||
|
func.setVisibility(SymbolTable::Visibility::Public);
|
||||||
|
func.setName(it->second);
|
||||||
|
func.walk(rewriteOpWithNnModuleTypeOperand);
|
||||||
|
SmallVector<unsigned> argsToErase;
|
||||||
|
for (auto arg : llvm::enumerate(func.getArguments())) {
|
||||||
|
if (!arg.value().getType().isa<NnModuleType>())
|
||||||
|
continue;
|
||||||
|
if (!arg.value().use_empty()) {
|
||||||
|
// TODO: Improve this based on real user use cases.
|
||||||
|
// This is a diagnostic that users will hit if they do not conform to
|
||||||
|
// the supported subset of TorchScript.
|
||||||
|
auto diag = func.emitError().append(
|
||||||
|
"func argument at index ", arg.index(),
|
||||||
|
" has uses that were not able to be converted");
|
||||||
|
for (Operation *user : arg.value().getUsers())
|
||||||
|
diag.attachNote(user->getLoc()).append("see user here");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
argsToErase.push_back(arg.index());
|
||||||
|
}
|
||||||
|
func.eraseArguments(argsToErase);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void ObjectGraphGlobalizer::removeObjectGraph() {
|
||||||
|
for (Operation &op : llvm::make_early_inc_range(*module.getBody())) {
|
||||||
|
if (isa<ClassTypeOp, NnModuleOp>(op))
|
||||||
|
op.erase();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class GlobalizeObjectGraphPass
|
||||||
|
: public GlobalizeObjectGraphBase<GlobalizeObjectGraphPass> {
|
||||||
|
void runOnOperation() override {
|
||||||
|
if (failed(ObjectGraphGlobalizer(getOperation()).globalizeObjectGraph()))
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass() {
|
||||||
|
return std::make_unique<GlobalizeObjectGraphPass>();
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
//===- PassDetail.h - Pass details ------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
||||||
|
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace NPCOMP {
|
||||||
|
namespace Torch {
|
||||||
|
|
||||||
|
#define GEN_PASS_CLASSES
|
||||||
|
#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc"
|
||||||
|
|
||||||
|
} // namespace Torch
|
||||||
|
} // namespace NPCOMP
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
|
@ -0,0 +1,20 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pass registration
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
#define GEN_PASS_REGISTRATION
|
||||||
|
#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc"
|
||||||
|
} // end namespace
|
||||||
|
|
||||||
|
void mlir::NPCOMP::registerTorchPasses() { ::registerPasses(); }
|
|
@ -21,6 +21,7 @@
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||||
#include "npcomp/Dialect/TCP/Transforms/Passes.h"
|
#include "npcomp/Dialect/TCP/Transforms/Passes.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||||
|
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||||
#include "npcomp/Typing/Transforms/Passes.h"
|
#include "npcomp/Typing/Transforms/Passes.h"
|
||||||
|
|
||||||
#include "npcomp/Conversion/Passes.h"
|
#include "npcomp/Conversion/Passes.h"
|
||||||
|
@ -47,5 +48,6 @@ void mlir::NPCOMP::registerAllPasses() {
|
||||||
mlir::NPCOMP::registerNumpyPasses();
|
mlir::NPCOMP::registerNumpyPasses();
|
||||||
mlir::NPCOMP::registerTCFPasses();
|
mlir::NPCOMP::registerTCFPasses();
|
||||||
mlir::NPCOMP::registerTCPPasses();
|
mlir::NPCOMP::registerTCPPasses();
|
||||||
|
mlir::NPCOMP::registerTorchPasses();
|
||||||
mlir::NPCOMP::registerTypingPasses();
|
mlir::NPCOMP::registerTypingPasses();
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
// RUN: npcomp-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
|
||||||
|
|
||||||
|
torch.class_type @c1 {}
|
||||||
|
torch.class_type @c2 {}
|
||||||
|
|
||||||
|
// expected-note @+1 {{see other root module here}}
|
||||||
|
torch.nn_module {} : !torch.nn.Module<"c1">
|
||||||
|
// expected-error @+1 {{found more than one root module (module that is not a child of any other module)}}
|
||||||
|
torch.nn_module {} : !torch.nn.Module<"c2">
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// expected-error @+1 {{class type has more than one instance: the current TorchScript supported subset only allows single instances}}
|
||||||
|
torch.class_type @child {}
|
||||||
|
torch.class_type @parent {
|
||||||
|
torch.attr "m1" : !torch.nn.Module<"child">
|
||||||
|
torch.attr "m2" : !torch.nn.Module<"child">
|
||||||
|
}
|
||||||
|
|
||||||
|
// expected-note @+1 {{see instance here}}
|
||||||
|
%0 = torch.nn_module {} : !torch.nn.Module<"child">
|
||||||
|
// expected-note @+1 {{see instance here}}
|
||||||
|
%1 = torch.nn_module {} : !torch.nn.Module<"child">
|
||||||
|
|
||||||
|
%root = torch.nn_module {
|
||||||
|
torch.slot "m1", %0 : !torch.nn.Module<"child">
|
||||||
|
torch.slot "m2", %1 : !torch.nn.Module<"child">
|
||||||
|
} : !torch.nn.Module<"parent">
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
torch.class_type @c {
|
||||||
|
torch.method "f", @f
|
||||||
|
}
|
||||||
|
// expected-error @+1 {{func argument at index 1 has uses that were not able to be converted}}
|
||||||
|
func private @f(%arg0: !torch.nn.Module<"c">, %arg1: !torch.nn.Module<"c">) {
|
||||||
|
// expected-note @+1 {{see user here}}
|
||||||
|
%0 = basicpy.build_list %arg1 : (!torch.nn.Module<"c">) -> !basicpy.ListType
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.nn_module {} : !torch.nn.Module<"c">
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// expected-error @+1 {{reachable by multiple paths from root object: '<root>.m' and '<root>.m2'}}
|
||||||
|
torch.class_type @child {
|
||||||
|
torch.attr "float" : f64
|
||||||
|
}
|
||||||
|
torch.class_type @parent {
|
||||||
|
torch.attr "m" : !torch.nn.Module<"child">
|
||||||
|
torch.attr "m2" : !torch.nn.Module<"child">
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
%c42 = std.constant 42.0 : f64
|
||||||
|
%child = torch.nn_module {
|
||||||
|
torch.slot "float", %c42 : f64
|
||||||
|
} : !torch.nn.Module<"child">
|
||||||
|
%parent = torch.nn_module {
|
||||||
|
torch.slot "m", %child : !torch.nn.Module<"child">
|
||||||
|
torch.slot "m2", %child : !torch.nn.Module<"child">
|
||||||
|
} : !torch.nn.Module<"parent">
|
|
@ -0,0 +1,39 @@
|
||||||
|
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
|
torch.class_type @c {
|
||||||
|
torch.attr "float" : f64
|
||||||
|
torch.method "test_get", @test_get
|
||||||
|
torch.method "test_set", @test_set
|
||||||
|
torch.method "test_call", @test_call
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @test_get() -> f64 {
|
||||||
|
// CHECK: %[[V:.*]] = torch.global_slot.get @float : f64
|
||||||
|
// CHECK: return %[[V]] : f64
|
||||||
|
func private @test_get(%arg0: !torch.nn.Module<"c">) -> f64 {
|
||||||
|
%0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> f64
|
||||||
|
return %0 : f64
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @test_set(
|
||||||
|
// CHECK-SAME: %[[A:.*]]: f64) {
|
||||||
|
// CHECK: torch.global_slot.set @float = %[[A]] : f64
|
||||||
|
// CHECK: return
|
||||||
|
func private @test_set(%arg0: !torch.nn.Module<"c">, %arg1: f64) {
|
||||||
|
torch.prim.SetAttr %arg0["float"] = %arg1 : !torch.nn.Module<"c">, f64
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @test_call(
|
||||||
|
// CHECK-SAME: %[[A:.*]]: f64) -> f64 {
|
||||||
|
// CHECK: %[[V:.*]] = call @test_call(%[[A]]) : (f64) -> f64
|
||||||
|
// CHECK: return %[[V]] : f64
|
||||||
|
func private @test_call(%arg0: !torch.nn.Module<"c">, %arg1: f64) -> f64 {
|
||||||
|
%0 = torch.prim.CallMethod %arg0["test_call"] (%arg1) : !torch.nn.Module<"c">, (f64) -> f64
|
||||||
|
return %0 : f64
|
||||||
|
}
|
||||||
|
|
||||||
|
%c42 = std.constant 42.0 : f64
|
||||||
|
torch.nn_module {
|
||||||
|
torch.slot "float", %c42 : f64
|
||||||
|
} : !torch.nn.Module<"c">
|
|
@ -0,0 +1,25 @@
|
||||||
|
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
|
// Check that linkage names consist of the dotted path from the root.
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.global_slot @m.float : f64
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @__torch_global_slot_initializer() {
|
||||||
|
// CHECK: %[[C42:.*]] = constant 4.200000e+01 : f64
|
||||||
|
// CHECK: torch.global_slot.set @m.float = %[[C42]] : f64
|
||||||
|
// CHECK: return
|
||||||
|
|
||||||
|
torch.class_type @child {
|
||||||
|
torch.attr "float" : f64
|
||||||
|
}
|
||||||
|
torch.class_type @parent {
|
||||||
|
torch.attr "m" : !torch.nn.Module<"child">
|
||||||
|
}
|
||||||
|
|
||||||
|
%c42 = std.constant 42.0 : f64
|
||||||
|
%child = torch.nn_module {
|
||||||
|
torch.slot "float", %c42 : f64
|
||||||
|
} : !torch.nn.Module<"child">
|
||||||
|
%parent = torch.nn_module {
|
||||||
|
torch.slot "m", %child : !torch.nn.Module<"child">
|
||||||
|
} : !torch.nn.Module<"parent">
|
|
@ -0,0 +1,63 @@
|
||||||
|
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
|
// Basic case.
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.global_slot @b : !basicpy.BoolType
|
||||||
|
// CHECK: torch.global_slot @i : i64
|
||||||
|
// CHECK: torch.global_slot @f : f64
|
||||||
|
// CHECK: torch.global_slot @a : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @__torch_global_slot_initializer() {
|
||||||
|
// CHECK: %[[CB:.*]] = basicpy.bool_constant true
|
||||||
|
// CHECK: torch.global_slot.set @b = %[[CB]] : !basicpy.BoolType
|
||||||
|
// CHECK: %[[CI:.*]] = basicpy.numeric_constant 3 : i64
|
||||||
|
// CHECK: torch.global_slot.set @i = %[[CI]] : i64
|
||||||
|
// CHECK: %[[CF:.*]] = basicpy.numeric_constant 4.250000e+01 : f64
|
||||||
|
// CHECK: torch.global_slot.set @f = %[[CF]] : f64
|
||||||
|
// CHECK: %[[C:.*]] = constant dense<1.000000e+00> : tensor<1xf32>
|
||||||
|
// CHECK: %[[CA:.*]] = numpy.create_array_from_tensor %[[C]] : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
// CHECK: torch.global_slot.set @a = %[[CA]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
// CHECK: return
|
||||||
|
|
||||||
|
torch.class_type @c {
|
||||||
|
torch.attr "b" : !basicpy.BoolType
|
||||||
|
torch.attr "i" : i64
|
||||||
|
torch.attr "f" : f64
|
||||||
|
torch.attr "a" : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
}
|
||||||
|
|
||||||
|
%bool_true = basicpy.bool_constant true
|
||||||
|
%i = basicpy.numeric_constant 3 : i64
|
||||||
|
%f = basicpy.numeric_constant 4.250000e+01 : f64
|
||||||
|
%cst = constant dense<1.000000e+00> : tensor<1xf32>
|
||||||
|
%a = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
torch.nn_module {
|
||||||
|
torch.slot "b", %bool_true : !basicpy.BoolType
|
||||||
|
torch.slot "i", %i : i64
|
||||||
|
torch.slot "f", %f : f64
|
||||||
|
torch.slot "a", %a : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
} : !torch.nn.Module<"c">
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Same SSA value used as initializer for multiple slots.
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.global_slot @b1 : !basicpy.BoolType
|
||||||
|
// CHECK-LABEL: torch.global_slot @b2 : !basicpy.BoolType
|
||||||
|
// CHECK-LABEL: func @__torch_global_slot_initializer() {
|
||||||
|
// CHECK: %[[TRUE:.*]] = basicpy.bool_constant true
|
||||||
|
// CHECK: torch.global_slot.set @b1 = %[[TRUE]] : !basicpy.BoolType
|
||||||
|
// CHECK: torch.global_slot.set @b2 = %[[TRUE]] : !basicpy.BoolType
|
||||||
|
// CHECK: return
|
||||||
|
// CHECK: }
|
||||||
|
|
||||||
|
torch.class_type @c {
|
||||||
|
torch.attr "b1" : !basicpy.BoolType
|
||||||
|
torch.attr "b2" : !basicpy.BoolType
|
||||||
|
}
|
||||||
|
|
||||||
|
%bool_true = basicpy.bool_constant true
|
||||||
|
torch.nn_module {
|
||||||
|
torch.slot "b1", %bool_true : !basicpy.BoolType
|
||||||
|
torch.slot "b2", %bool_true : !basicpy.BoolType
|
||||||
|
} : !torch.nn.Module<"c">
|
|
@ -2,14 +2,98 @@
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
torch.nn_module {
|
torch.class_type @c {}
|
||||||
// expected-error @+1 {{'func' op is not allowed inside `torch.nn_module`}}
|
%0 = torch.nn_module {
|
||||||
|
// expected-error @+1 {{'func' op is not allowed inside 'torch.nn_module'}}
|
||||||
|
func @f()
|
||||||
|
} : !torch.nn.Module<"c">
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
torch.class_type @c {}
|
||||||
|
%c0 = constant 0 : i64
|
||||||
|
// expected-error @+1 {{number of 'torch.slot's in a 'torch.nn_module' must match number of 'torch.attr's in the corresponding 'torch.class_type'}}
|
||||||
|
%0 = torch.nn_module {
|
||||||
|
torch.slot "f", %c0 : i64
|
||||||
|
} : !torch.nn.Module<"c">
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
torch.class_type @c {
|
||||||
|
// expected-note @+1 {{see torch.attr at corresponding index 0 here}}
|
||||||
|
torch.attr "g" : i64
|
||||||
|
}
|
||||||
|
%c0 = constant 0 : i64
|
||||||
|
%0 = torch.nn_module {
|
||||||
|
// expected-error @+1 {{'torch.slot' op is expected to match type and name of 'torch.attr "g" : i64'}}
|
||||||
|
torch.slot "f", %c0 : i64
|
||||||
|
} : !torch.nn.Module<"c">
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
torch.class_type @c {
|
||||||
|
// expected-error @+1 {{'func' op is not allowed inside `torch.class_type`}}
|
||||||
func @f()
|
func @f()
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
torch.nn_module {
|
// expected-error @+1 {{has duplicate attr/method with name 'a'}}
|
||||||
// expected-error @+1 {{'invalidSym' does not reference a valid function}}
|
torch.class_type @c {
|
||||||
|
// expected-note @+1 {{see first conflicting attr/method here}}
|
||||||
|
torch.attr "a" : i64
|
||||||
|
// expected-note @+1 {{see second conflicting attr/method here}}
|
||||||
|
torch.attr "a" : i64
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
torch.class_type @c {
|
||||||
|
// expected-error @+1 {{'@invalidSym' does not reference a valid function}}
|
||||||
torch.method "f", @invalidSym
|
torch.method "f", @invalidSym
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
torch.class_type @c {
|
||||||
|
// expected-error @+1 {{'@f' must reference a private function}}
|
||||||
|
torch.method "f", @f
|
||||||
|
}
|
||||||
|
|
||||||
|
func @f(%arg0: !torch.nn.Module<"c">) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
torch.class_type @c {
|
||||||
|
// expected-error @+1 {{'@f' must reference a function that is defined (not merely declared)}}
|
||||||
|
torch.method "f", @f
|
||||||
|
}
|
||||||
|
|
||||||
|
func private @f(%arg0: !torch.nn.Module<"c">)
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func private @f() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
torch.class_type @c {
|
||||||
|
// expected-error @+1 {{the referenced function 'f' must have a first argument of type '!torch.nn.Module<"c">'}}
|
||||||
|
torch.method "f", @f
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func private @f(!torch.nn.Module<"other_c">) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
torch.class_type @c {
|
||||||
|
// expected-error @+1 {{the referenced function 'f' must have a first argument of type '!torch.nn.Module<"c">'}}
|
||||||
|
torch.method "f", @f
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// expected-error @+1 {{'a' does not reference a valid class type}}
|
||||||
|
%m = torch.nn_module {} : !torch.nn.Module<"a">
|
||||||
|
|
|
@ -14,16 +14,28 @@ func @kernel_call(%arg0 : si32, %arg1 : tensor<3x4xf32>) -> tensor<*xf32> {
|
||||||
%num = basicpy.numeric_constant 4.250000e+01 : f64
|
%num = basicpy.numeric_constant 4.250000e+01 : f64
|
||||||
%cst = constant dense<1.000000e+00> : tensor<1xf32>
|
%cst = constant dense<1.000000e+00> : tensor<1xf32>
|
||||||
%array = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
%array = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
func @f(%arg0: !torch.nn.Module) {
|
%none = basicpy.singleton : !basicpy.NoneType
|
||||||
|
func private @f(%arg0: !torch.nn.Module<"test">) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
%submodule = torch.nn_module {}
|
|
||||||
|
|
||||||
torch.nn_module {
|
torch.class_type @empty {}
|
||||||
torch.attr "b", %bool_true : !basicpy.BoolType
|
%submodule = torch.nn_module {} : !torch.nn.Module<"empty">
|
||||||
torch.attr "i", %num3_i64 : i64
|
|
||||||
torch.attr "f", %num : f64
|
torch.class_type @test {
|
||||||
torch.attr "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
|
torch.attr "b" : !basicpy.BoolType
|
||||||
torch.attr "submodule", %submodule : !torch.nn.Module
|
torch.attr "i" : i64
|
||||||
|
torch.attr "f" : f64
|
||||||
|
torch.attr "t" : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
torch.attr "submodule" : !torch.nn.Module<"empty">
|
||||||
|
torch.attr "ob" : !torch.optional<!basicpy.BoolType>
|
||||||
torch.method "method", @f
|
torch.method "method", @f
|
||||||
}
|
}
|
||||||
|
torch.nn_module {
|
||||||
|
torch.slot "b", %bool_true : !basicpy.BoolType
|
||||||
|
torch.slot "i", %num3_i64 : i64
|
||||||
|
torch.slot "f", %num : f64
|
||||||
|
torch.slot "t", %array : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
torch.slot "submodule", %submodule : !torch.nn.Module<"empty">
|
||||||
|
torch.slot "ob", %none : !basicpy.NoneType
|
||||||
|
} : !torch.nn.Module<"test">
|
||||||
|
|
Loading…
Reference in New Issue