mirror of https://github.com/llvm/torch-mlir
Add resnet inference jupyter notebook.
This takes the example from torchscript_resnet18_e2e.py and puts it into a slightly cleaned up notebook form. It's still a little rough around the edges. Areas for improvement: - Installation / setup. - API usability. Also, - Add `npcomp-backend-to-iree-frontend-pipeline` since we will be adding more stuff there. - Slight cleanups.pull/273/head
parent
f71845ea75
commit
902c2e579b
|
@ -2,6 +2,7 @@
|
||||||
.vscode
|
.vscode
|
||||||
.env
|
.env
|
||||||
*.code-workspace
|
*.code-workspace
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
/build/
|
/build/
|
||||||
__pycache__
|
__pycache__
|
||||||
|
|
|
@ -51,7 +51,7 @@ Meaning of options:
|
||||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
||||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||||
''')
|
''')
|
||||||
parser.add_argument('--filter', default='.*', help='''
|
parser.add_argument('-f', '--filter', default='.*', help='''
|
||||||
Regular expression specifying which tests to include in this run.
|
Regular expression specifying which tests to include in this run.
|
||||||
''')
|
''')
|
||||||
parser.add_argument('-v', '--verbose',
|
parser.add_argument('-v', '--verbose',
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -15,21 +15,24 @@ import npcomp
|
||||||
from npcomp.compiler.pytorch.backend import refjit, frontend_lowering, iree
|
from npcomp.compiler.pytorch.backend import refjit, frontend_lowering, iree
|
||||||
from npcomp.compiler.utils import logging
|
from npcomp.compiler.utils import logging
|
||||||
|
|
||||||
logging.enable()
|
|
||||||
|
|
||||||
mb = torch_mlir.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
|
|
||||||
def load_and_preprocess_image(url: str):
|
def load_and_preprocess_image(url: str):
|
||||||
img = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
headers = {
|
||||||
|
'User-Agent':
|
||||||
|
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36'
|
||||||
|
}
|
||||||
|
img = Image.open(requests.get(url, headers=headers,
|
||||||
|
stream=True).raw).convert("RGB")
|
||||||
# preprocessing pipeline
|
# preprocessing pipeline
|
||||||
preprocess = transforms.Compose(
|
preprocess = transforms.Compose([
|
||||||
[
|
transforms.Resize(256),
|
||||||
transforms.Resize(256),
|
transforms.CenterCrop(224),
|
||||||
transforms.CenterCrop(224),
|
transforms.ToTensor(),
|
||||||
transforms.ToTensor(),
|
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
std=[0.229, 0.224, 0.225]),
|
||||||
]
|
])
|
||||||
)
|
|
||||||
img_preprocessed = preprocess(img)
|
img_preprocessed = preprocess(img)
|
||||||
return torch.unsqueeze(img_preprocessed, 0)
|
return torch.unsqueeze(img_preprocessed, 0)
|
||||||
|
|
||||||
|
@ -78,6 +81,15 @@ class TestModule(torch.nn.Module):
|
||||||
return self.s.forward(x)
|
return self.s.forward(x)
|
||||||
|
|
||||||
|
|
||||||
|
image_url = (
|
||||||
|
"https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
|
||||||
|
)
|
||||||
|
import sys
|
||||||
|
|
||||||
|
print("load image from " + image_url, file=sys.stderr)
|
||||||
|
img = load_and_preprocess_image(image_url)
|
||||||
|
labels = load_labels()
|
||||||
|
|
||||||
test_module = TestModule()
|
test_module = TestModule()
|
||||||
class_annotator = torch_mlir.ClassAnnotator()
|
class_annotator = torch_mlir.ClassAnnotator()
|
||||||
recursivescriptmodule = torch.jit.script(test_module)
|
recursivescriptmodule = torch.jit.script(test_module)
|
||||||
|
@ -88,7 +100,10 @@ class_annotator.exportPath(recursivescriptmodule._c._type(), ["forward"])
|
||||||
class_annotator.annotateArgs(
|
class_annotator.annotateArgs(
|
||||||
recursivescriptmodule._c._type(),
|
recursivescriptmodule._c._type(),
|
||||||
["forward"],
|
["forward"],
|
||||||
[None, ([-1, -1, -1, -1], torch.float32, True),],
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||||
mb.import_module(recursivescriptmodule._c, class_annotator)
|
mb.import_module(recursivescriptmodule._c, class_annotator)
|
||||||
|
@ -97,10 +112,4 @@ backend = refjit.RefjitNpcompBackend()
|
||||||
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
|
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
|
||||||
jit_module = backend.load(compiled)
|
jit_module = backend.load(compiled)
|
||||||
|
|
||||||
image_url = (
|
|
||||||
"https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
|
|
||||||
)
|
|
||||||
print("load image from " + image_url)
|
|
||||||
img = load_and_preprocess_image(image_url)
|
|
||||||
labels = load_labels()
|
|
||||||
predictions(test_module.forward, jit_module.forward, img, labels)
|
predictions(test_module.forward, jit_module.forward, img, labels)
|
||||||
|
|
|
@ -18,6 +18,10 @@ namespace IREEBackend {
|
||||||
/// Registers all IREEBackend passes.
|
/// Registers all IREEBackend passes.
|
||||||
void registerIREEBackendPasses();
|
void registerIREEBackendPasses();
|
||||||
|
|
||||||
|
/// Create a pipeline that runs all passes needed to lower the npcomp backend
|
||||||
|
/// contract to IREE's frontend contract.
|
||||||
|
void createNpcompBackendToIreeFrontendPipeline(OpPassManager &pm);
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createLowerLinkagePass();
|
std::unique_ptr<OperationPass<ModuleOp>> createLowerLinkagePass();
|
||||||
|
|
||||||
} // namespace IREEBackend
|
} // namespace IREEBackend
|
||||||
|
|
|
@ -16,10 +16,20 @@ using namespace mlir::NPCOMP;
|
||||||
using namespace mlir::NPCOMP::IREEBackend;
|
using namespace mlir::NPCOMP::IREEBackend;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
#define GEN_PASS_REGISTRATION
|
// This pass lowers the public ABI of the module to the primitives exposed by
|
||||||
#include "npcomp/Backend/IREE/Passes.h.inc"
|
// the refbackrt dialect.
|
||||||
} // end namespace
|
class LowerLinkagePass : public LowerLinkageBase<LowerLinkagePass> {
|
||||||
|
void runOnOperation() override {
|
||||||
|
ModuleOp module = getOperation();
|
||||||
|
for (auto func : module.getOps<FuncOp>()) {
|
||||||
|
if (func.getVisibility() == SymbolTable::Visibility::Public)
|
||||||
|
func->setAttr("iree.module.export", UnitAttr::get(&getContext()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void mlir::NPCOMP::IREEBackend::registerIREEBackendPasses() {
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
::registerPasses();
|
mlir::NPCOMP::IREEBackend::createLowerLinkagePass() {
|
||||||
|
return std::make_unique<LowerLinkagePass>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,20 +16,21 @@ using namespace mlir::NPCOMP;
|
||||||
using namespace mlir::NPCOMP::IREEBackend;
|
using namespace mlir::NPCOMP::IREEBackend;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// This pass lowers the public ABI of the module to the primitives exposed by
|
#define GEN_PASS_REGISTRATION
|
||||||
// the refbackrt dialect.
|
#include "npcomp/Backend/IREE/Passes.h.inc"
|
||||||
class LowerLinkagePass : public LowerLinkageBase<LowerLinkagePass> {
|
} // end namespace
|
||||||
void runOnOperation() override {
|
|
||||||
ModuleOp module = getOperation();
|
|
||||||
for (auto func : module.getOps<FuncOp>()) {
|
|
||||||
if (func.getVisibility() == SymbolTable::Visibility::Public)
|
|
||||||
func->setAttr("iree.module.export", UnitAttr::get(&getContext()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
void mlir::NPCOMP::IREEBackend::createNpcompBackendToIreeFrontendPipeline(
|
||||||
mlir::NPCOMP::IREEBackend::createLowerLinkagePass() {
|
OpPassManager &pm) {
|
||||||
return std::make_unique<LowerLinkagePass>();
|
pm.addPass(createLowerLinkagePass());
|
||||||
|
}
|
||||||
|
|
||||||
|
void mlir::NPCOMP::IREEBackend::registerIREEBackendPasses() {
|
||||||
|
::registerPasses();
|
||||||
|
|
||||||
|
mlir::PassPipelineRegistration<>(
|
||||||
|
"npcomp-backend-to-iree-frontend-pipeline",
|
||||||
|
"Pipeline lowering the npcomp backend contract IR to IREE's frontend "
|
||||||
|
"contract.",
|
||||||
|
mlir::NPCOMP::IREEBackend::createNpcompBackendToIreeFrontendPipeline);
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,9 +19,6 @@ __all__ = [
|
||||||
"IreeNpcompBackend",
|
"IreeNpcompBackend",
|
||||||
]
|
]
|
||||||
|
|
||||||
PREPARE_FOR_IREE_PASSES = (
|
|
||||||
"npcomp-iree-backend-lower-linkage",
|
|
||||||
)
|
|
||||||
|
|
||||||
class IreeModuleInvoker:
|
class IreeModuleInvoker:
|
||||||
"""Wrapper around a native IREE module for calling functions."""
|
"""Wrapper around a native IREE module for calling functions."""
|
||||||
|
@ -88,7 +85,7 @@ class IreeNpcompBackend(NpcompBackend):
|
||||||
if self._debug:
|
if self._debug:
|
||||||
logging.debug("IR passed to IREE compiler backend:\n{}",
|
logging.debug("IR passed to IREE compiler backend:\n{}",
|
||||||
imported_module)
|
imported_module)
|
||||||
pipeline_str = ",".join(PREPARE_FOR_IREE_PASSES)
|
pipeline_str = "npcomp-backend-to-iree-frontend-pipeline"
|
||||||
if self._debug:
|
if self._debug:
|
||||||
logging.debug("Running Prepare For IREE pipeline '{}'", pipeline_str)
|
logging.debug("Running Prepare For IREE pipeline '{}'", pipeline_str)
|
||||||
pm = PassManager.parse(pipeline_str)
|
pm = PassManager.parse(pipeline_str)
|
||||||
|
|
Loading…
Reference in New Issue