mirror of https://github.com/llvm/torch-mlir
[cleanup] Put the root class type for exportPath first.
This is more consistent and intuitive -- usually the object being "indexed" or used as a "context" for a later parameter goes first.pull/200/head
parent
e749074bae
commit
c3f1f8ebf4
|
@ -88,10 +88,10 @@ static void exportNoneRecurse(ClassAnnotator &classAnnotator,
|
|||
|
||||
void ClassAnnotator::exportNone(c10::ClassType &rootClassType) {
|
||||
exportNoneRecurse(*this, &rootClassType);
|
||||
}
|
||||
}
|
||||
|
||||
void ClassAnnotator::exportPath(std::vector<std::string> exportedPath,
|
||||
c10::ClassType &rootClassType) {
|
||||
void ClassAnnotator::exportPath(c10::ClassType &rootClassType,
|
||||
std::vector<std::string> exportedPath) {
|
||||
if (exportedPath.size() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"Empty exported path. Can only export a property of a class.");
|
||||
|
|
|
@ -123,8 +123,8 @@ public:
|
|||
// For example, if `exportedPath = ['a', 'b']`, then `rootClassType` should
|
||||
// have a submodule `a` and that submodule should have a method or attribute
|
||||
// `b`.
|
||||
void exportPath(std::vector<std::string> exportedPath,
|
||||
c10::ClassType &rootClassType);
|
||||
void exportPath(c10::ClassType &rootClassType,
|
||||
std::vector<std::string> exportedPath);
|
||||
// Mark everything as not-exported.
|
||||
//
|
||||
// This is kind of useless by itself, but together with `exportPath` allows
|
||||
|
|
|
@ -41,8 +41,8 @@ annotator = torch_mlir.ClassAnnotator()
|
|||
class_type = recursivescriptmodule._c._type()
|
||||
|
||||
annotator.exportNone(class_type)
|
||||
annotator.exportPath(['s', 'exported'], class_type)
|
||||
annotator.exportPath(['s', 'forward'], class_type)
|
||||
annotator.exportPath(class_type, ['s', 'exported'])
|
||||
annotator.exportPath(class_type, ['s', 'forward'])
|
||||
annotator.annotateShapesAndDtypes(class_type, ['forward'], [
|
||||
None,
|
||||
((1024, 2), torch.float32),
|
||||
|
|
|
@ -24,18 +24,18 @@ annotator = torch_mlir.ClassAnnotator()
|
|||
class_type = recursivescriptmodule._c._type()
|
||||
|
||||
try:
|
||||
annotator.exportPath(['a'], class_type)
|
||||
annotator.exportPath(class_type, ['a'])
|
||||
except Exception as e:
|
||||
# CHECK: class '__torch__.TestModule' does not have a method or attribute called 'a'
|
||||
print(e)
|
||||
try:
|
||||
annotator.exportPath([], class_type)
|
||||
annotator.exportPath(class_type, [])
|
||||
except Exception as e:
|
||||
# CHECK: Empty exported path. Can only export a property of a class.
|
||||
print(e)
|
||||
|
||||
try:
|
||||
annotator.exportPath(['a', 'b'], class_type)
|
||||
annotator.exportPath(class_type, ['a', 'b'])
|
||||
except Exception as e:
|
||||
# This error is generated by PyTorch itself, so be a bit defensive about changes.
|
||||
# CHECK: __torch__.TestModule {{.*}} 'a'
|
||||
|
|
|
@ -41,8 +41,8 @@ class_type = recursivescriptmodule._c._type()
|
|||
# CHECK: torch.method private "not_exported_method", @{{.*}}
|
||||
# CHECK: }
|
||||
annotator.exportNone(class_type)
|
||||
annotator.exportPath(['s', 'exported'], class_type)
|
||||
annotator.exportPath(['s', 'forward'], class_type)
|
||||
annotator.exportPath(class_type, ['s', 'exported'])
|
||||
annotator.exportPath(class_type, ['s', 'forward'])
|
||||
|
||||
# # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||
mb.import_module(recursivescriptmodule._c, annotator)
|
||||
|
|
|
@ -33,8 +33,8 @@ class_type = recursivescriptmodule._c._type()
|
|||
# CHECK: torch.method private "not_exported_method", @{{.*}}
|
||||
# CHECK: }
|
||||
annotator.exportNone(class_type)
|
||||
annotator.exportPath(['exported'], class_type)
|
||||
annotator.exportPath(['forward'], class_type)
|
||||
annotator.exportPath(class_type, ['exported'])
|
||||
annotator.exportPath(class_type, ['forward'])
|
||||
|
||||
# # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||
mb.import_module(recursivescriptmodule._c, annotator)
|
||||
|
|
|
@ -44,7 +44,7 @@ Can pass repeatedly.
|
|||
if args.exported_name is not None:
|
||||
class_annotator.exportNone(module._c._type())
|
||||
for name in args.exported_name:
|
||||
class_annotator.exportPath(name.split("."), module._c._type())
|
||||
class_annotator.exportPath(module._c._type(), name.split("."))
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
mb.import_module(module._c, class_annotator)
|
||||
mb.module.operation.print(large_elements_limit=16)
|
||||
|
|
Loading…
Reference in New Issue