[cleanup] Put the root class type for exportPath first.

This is more consistent and intuitive -- usually the object being
"indexed" or used as a "context" for a later parameter goes first.
pull/200/head
Sean Silva 2021-03-30 16:38:19 -07:00
parent e749074bae
commit c3f1f8ebf4
7 changed files with 15 additions and 15 deletions

View File

@ -88,10 +88,10 @@ static void exportNoneRecurse(ClassAnnotator &classAnnotator,
void ClassAnnotator::exportNone(c10::ClassType &rootClassType) {
exportNoneRecurse(*this, &rootClassType);
}
}
void ClassAnnotator::exportPath(std::vector<std::string> exportedPath,
c10::ClassType &rootClassType) {
void ClassAnnotator::exportPath(c10::ClassType &rootClassType,
std::vector<std::string> exportedPath) {
if (exportedPath.size() == 0) {
throw std::invalid_argument(
"Empty exported path. Can only export a property of a class.");

View File

@ -123,8 +123,8 @@ public:
// For example, if `exportedPath = ['a', 'b']`, then `rootClassType` should
// have a submodule `a` and that submodule should have a method or attribute
// `b`.
void exportPath(std::vector<std::string> exportedPath,
c10::ClassType &rootClassType);
void exportPath(c10::ClassType &rootClassType,
std::vector<std::string> exportedPath);
// Mark everything as not-exported.
//
// This is kind of useless by itself, but together with `exportPath` allows

View File

@ -41,8 +41,8 @@ annotator = torch_mlir.ClassAnnotator()
class_type = recursivescriptmodule._c._type()
annotator.exportNone(class_type)
annotator.exportPath(['s', 'exported'], class_type)
annotator.exportPath(['s', 'forward'], class_type)
annotator.exportPath(class_type, ['s', 'exported'])
annotator.exportPath(class_type, ['s', 'forward'])
annotator.annotateShapesAndDtypes(class_type, ['forward'], [
None,
((1024, 2), torch.float32),

View File

@ -24,18 +24,18 @@ annotator = torch_mlir.ClassAnnotator()
class_type = recursivescriptmodule._c._type()
try:
annotator.exportPath(['a'], class_type)
annotator.exportPath(class_type, ['a'])
except Exception as e:
# CHECK: class '__torch__.TestModule' does not have a method or attribute called 'a'
print(e)
try:
annotator.exportPath([], class_type)
annotator.exportPath(class_type, [])
except Exception as e:
# CHECK: Empty exported path. Can only export a property of a class.
print(e)
try:
annotator.exportPath(['a', 'b'], class_type)
annotator.exportPath(class_type, ['a', 'b'])
except Exception as e:
# This error is generated by PyTorch itself, so be a bit defensive about changes.
# CHECK: __torch__.TestModule {{.*}} 'a'

View File

@ -41,8 +41,8 @@ class_type = recursivescriptmodule._c._type()
# CHECK: torch.method private "not_exported_method", @{{.*}}
# CHECK: }
annotator.exportNone(class_type)
annotator.exportPath(['s', 'exported'], class_type)
annotator.exportPath(['s', 'forward'], class_type)
annotator.exportPath(class_type, ['s', 'exported'])
annotator.exportPath(class_type, ['s', 'forward'])
# # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, annotator)

View File

@ -33,8 +33,8 @@ class_type = recursivescriptmodule._c._type()
# CHECK: torch.method private "not_exported_method", @{{.*}}
# CHECK: }
annotator.exportNone(class_type)
annotator.exportPath(['exported'], class_type)
annotator.exportPath(['forward'], class_type)
annotator.exportPath(class_type, ['exported'])
annotator.exportPath(class_type, ['forward'])
# # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, annotator)

View File

@ -44,7 +44,7 @@ Can pass repeatedly.
if args.exported_name is not None:
class_annotator.exportNone(module._c._type())
for name in args.exported_name:
class_annotator.exportPath(name.split("."), module._c._type())
class_annotator.exportPath(module._c._type(), name.split("."))
mb = torch_mlir.ModuleBuilder()
mb.import_module(module._c, class_annotator)
mb.module.operation.print(large_elements_limit=16)