From c3f1f8ebf4df151d618f092f75fe81e05e26c345 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Tue, 30 Mar 2021 16:38:19 -0700 Subject: [PATCH] [cleanup] Put the root class type for exportPath first. This is more consistent and intuitive -- usually the object being "indexed" or used as a "context" for a later parameter goes first. --- frontends/pytorch/csrc/builder/class_annotator.cpp | 6 +++--- frontends/pytorch/csrc/builder/class_annotator.h | 4 ++-- .../test/ivalue_import/annotations/class-annotator-repr.py | 4 ++-- .../pytorch/test/ivalue_import/annotations/export-error.py | 6 +++--- .../test/ivalue_import/annotations/export-recursive.py | 4 ++-- frontends/pytorch/test/ivalue_import/annotations/export.py | 4 ++-- frontends/pytorch/utils/pt_util.py | 2 +- 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/frontends/pytorch/csrc/builder/class_annotator.cpp b/frontends/pytorch/csrc/builder/class_annotator.cpp index 6b808d99e..c8f9dfcf4 100644 --- a/frontends/pytorch/csrc/builder/class_annotator.cpp +++ b/frontends/pytorch/csrc/builder/class_annotator.cpp @@ -88,10 +88,10 @@ static void exportNoneRecurse(ClassAnnotator &classAnnotator, void ClassAnnotator::exportNone(c10::ClassType &rootClassType) { exportNoneRecurse(*this, &rootClassType); -} +} -void ClassAnnotator::exportPath(std::vector exportedPath, - c10::ClassType &rootClassType) { +void ClassAnnotator::exportPath(c10::ClassType &rootClassType, + std::vector exportedPath) { if (exportedPath.size() == 0) { throw std::invalid_argument( "Empty exported path. Can only export a property of a class."); diff --git a/frontends/pytorch/csrc/builder/class_annotator.h b/frontends/pytorch/csrc/builder/class_annotator.h index a1e1b5d73..53e1fd83f 100644 --- a/frontends/pytorch/csrc/builder/class_annotator.h +++ b/frontends/pytorch/csrc/builder/class_annotator.h @@ -123,8 +123,8 @@ public: // For example, if `exportedPath = ['a', 'b']`, then `rootClassType` should // have a submodule `a` and that submodule should have a method or attribute // `b`. - void exportPath(std::vector exportedPath, - c10::ClassType &rootClassType); + void exportPath(c10::ClassType &rootClassType, + std::vector exportedPath); // Mark everything as not-exported. // // This is kind of useless by itself, but together with `exportPath` allows diff --git a/frontends/pytorch/test/ivalue_import/annotations/class-annotator-repr.py b/frontends/pytorch/test/ivalue_import/annotations/class-annotator-repr.py index 8547c33b3..202f71519 100644 --- a/frontends/pytorch/test/ivalue_import/annotations/class-annotator-repr.py +++ b/frontends/pytorch/test/ivalue_import/annotations/class-annotator-repr.py @@ -41,8 +41,8 @@ annotator = torch_mlir.ClassAnnotator() class_type = recursivescriptmodule._c._type() annotator.exportNone(class_type) -annotator.exportPath(['s', 'exported'], class_type) -annotator.exportPath(['s', 'forward'], class_type) +annotator.exportPath(class_type, ['s', 'exported']) +annotator.exportPath(class_type, ['s', 'forward']) annotator.annotateShapesAndDtypes(class_type, ['forward'], [ None, ((1024, 2), torch.float32), diff --git a/frontends/pytorch/test/ivalue_import/annotations/export-error.py b/frontends/pytorch/test/ivalue_import/annotations/export-error.py index cf8b1d29a..e27acb19d 100644 --- a/frontends/pytorch/test/ivalue_import/annotations/export-error.py +++ b/frontends/pytorch/test/ivalue_import/annotations/export-error.py @@ -24,18 +24,18 @@ annotator = torch_mlir.ClassAnnotator() class_type = recursivescriptmodule._c._type() try: - annotator.exportPath(['a'], class_type) + annotator.exportPath(class_type, ['a']) except Exception as e: # CHECK: class '__torch__.TestModule' does not have a method or attribute called 'a' print(e) try: - annotator.exportPath([], class_type) + annotator.exportPath(class_type, []) except Exception as e: # CHECK: Empty exported path. Can only export a property of a class. print(e) try: - annotator.exportPath(['a', 'b'], class_type) + annotator.exportPath(class_type, ['a', 'b']) except Exception as e: # This error is generated by PyTorch itself, so be a bit defensive about changes. # CHECK: __torch__.TestModule {{.*}} 'a' diff --git a/frontends/pytorch/test/ivalue_import/annotations/export-recursive.py b/frontends/pytorch/test/ivalue_import/annotations/export-recursive.py index a540e796e..73471e6f0 100644 --- a/frontends/pytorch/test/ivalue_import/annotations/export-recursive.py +++ b/frontends/pytorch/test/ivalue_import/annotations/export-recursive.py @@ -41,8 +41,8 @@ class_type = recursivescriptmodule._c._type() # CHECK: torch.method private "not_exported_method", @{{.*}} # CHECK: } annotator.exportNone(class_type) -annotator.exportPath(['s', 'exported'], class_type) -annotator.exportPath(['s', 'forward'], class_type) +annotator.exportPath(class_type, ['s', 'exported']) +annotator.exportPath(class_type, ['s', 'forward']) # # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. mb.import_module(recursivescriptmodule._c, annotator) diff --git a/frontends/pytorch/test/ivalue_import/annotations/export.py b/frontends/pytorch/test/ivalue_import/annotations/export.py index 46dec5b01..3b0483cb1 100644 --- a/frontends/pytorch/test/ivalue_import/annotations/export.py +++ b/frontends/pytorch/test/ivalue_import/annotations/export.py @@ -33,8 +33,8 @@ class_type = recursivescriptmodule._c._type() # CHECK: torch.method private "not_exported_method", @{{.*}} # CHECK: } annotator.exportNone(class_type) -annotator.exportPath(['exported'], class_type) -annotator.exportPath(['forward'], class_type) +annotator.exportPath(class_type, ['exported']) +annotator.exportPath(class_type, ['forward']) # # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. mb.import_module(recursivescriptmodule._c, annotator) diff --git a/frontends/pytorch/utils/pt_util.py b/frontends/pytorch/utils/pt_util.py index e6f1e4d89..65a37c041 100644 --- a/frontends/pytorch/utils/pt_util.py +++ b/frontends/pytorch/utils/pt_util.py @@ -44,7 +44,7 @@ Can pass repeatedly. if args.exported_name is not None: class_annotator.exportNone(module._c._type()) for name in args.exported_name: - class_annotator.exportPath(name.split("."), module._c._type()) + class_annotator.exportPath(module._c._type(), name.split(".")) mb = torch_mlir.ModuleBuilder() mb.import_module(module._c, class_annotator) mb.module.operation.print(large_elements_limit=16)