mirror of https://github.com/llvm/torch-mlir
Add resnet inference jupyter notebook.
This takes the example from torchscript_resnet18_e2e.py and puts it into a slightly cleaned-up notebook form. It is still a little rough around the edges. Areas for improvement:
- Installation / setup.
- API usability.

Also:
- Add `npcomp-backend-to-iree-frontend-pipeline`, since we will be adding more stuff there.
- Slight cleanups.

pull/273/head
parent f71845ea75
commit 902c2e579b
@@ -2,6 +2,7 @@
 .vscode
 .env
 *.code-workspace
+.ipynb_checkpoints
 
 /build/
 __pycache__
@@ -51,7 +51,7 @@ Meaning of options:
 "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
 "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
 ''')
-parser.add_argument('--filter', default='.*', help='''
+parser.add_argument('-f', '--filter', default='.*', help='''
 Regular expression specifying which tests to include in this run.
 ''')
 parser.add_argument('-v', '--verbose',
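A hedged sketch (not part of the diff) of the new short flag in use; the store_true action for -v and the test name below are assumptions made for illustration:

# Minimal standalone sketch of the changed argparse setup; only --filter/-f and
# -v come from the diff, everything else here is assumed.
import argparse
import re

parser = argparse.ArgumentParser()
parser.add_argument('-f', '--filter', default='.*', help='''
Regular expression specifying which tests to include in this run.
''')
parser.add_argument('-v', '--verbose', action='store_true')  # assumed action

args = parser.parse_args(['-f', 'Resnet18.*', '-v'])
print(bool(re.match(args.filter, 'Resnet18Module_basic')))  # -> True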
File diff suppressed because one or more lines are too long
@@ -15,21 +15,24 @@ import npcomp
 from npcomp.compiler.pytorch.backend import refjit, frontend_lowering, iree
 from npcomp.compiler.utils import logging
 
 logging.enable()
 
 mb = torch_mlir.ModuleBuilder()
 
 
 def load_and_preprocess_image(url: str):
-    img = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+    headers = {
+        'User-Agent':
+            'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36'
+    }
+    img = Image.open(requests.get(url, headers=headers,
+                                  stream=True).raw).convert("RGB")
-    # preprocessing pipeline
-    preprocess = transforms.Compose(
-        [
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-        ]
-    )
+    preprocess = transforms.Compose([
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225]),
+    ])
     img_preprocessed = preprocess(img)
     return torch.unsqueeze(img_preprocessed, 0)
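As a quick sanity check (not part of the diff), the helper above should now fetch the image with a browser-like User-Agent and return a single-image NCHW tensor:

# Hedged usage sketch: relies on load_and_preprocess_image as defined above and
# on network access; the URL is the one used later in the example.
img = load_and_preprocess_image(
    "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg")
print(img.shape)  # expected: torch.Size([1, 3, 224, 224])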
@@ -78,6 +81,15 @@ class TestModule(torch.nn.Module):
         return self.s.forward(x)
 
 
+image_url = (
+    "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
+)
+import sys
+
+print("load image from " + image_url, file=sys.stderr)
+img = load_and_preprocess_image(image_url)
+labels = load_labels()
+
 test_module = TestModule()
 class_annotator = torch_mlir.ClassAnnotator()
 recursivescriptmodule = torch.jit.script(test_module)
@@ -88,7 +100,10 @@ class_annotator.exportPath(recursivescriptmodule._c._type(), ["forward"])
 class_annotator.annotateArgs(
     recursivescriptmodule._c._type(),
     ["forward"],
-    [None, ([-1, -1, -1, -1], torch.float32, True),],
+    [
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ],
 )
 # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
 mb.import_module(recursivescriptmodule._c, class_annotator)
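A hedged reading of the annotation (not stated in the diff): the list carries one entry per argument of forward, None for self and a (shape, dtype, has_value_semantics) tuple for the image, with -1 marking a dynamic dimension. Under that assumption, pinning the input to the exact preprocessed size would look roughly like:

# Hypothetical alternative annotation: static 1x3x224x224 input instead of
# fully dynamic dims; the tuple layout mirrors the call above.
class_annotator.annotateArgs(
    recursivescriptmodule._c._type(),
    ["forward"],
    [
        None,
        ([1, 3, 224, 224], torch.float32, True),
    ],
)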
@@ -97,10 +112,4 @@ backend = refjit.RefjitNpcompBackend()
 compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
 jit_module = backend.load(compiled)
 
-image_url = (
-    "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
-)
-print("load image from " + image_url)
-img = load_and_preprocess_image(image_url)
-labels = load_labels()
 predictions(test_module.forward, jit_module.forward, img, labels)
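Since the example now imports `iree` alongside `refjit`, a hedged sketch of pointing the same flow at the IREE backend (assuming `IreeNpcompBackend` exposes the same compile/load interface as the refjit backend):

# Hedged sketch: swap the refjit backend for the IREE one; module import,
# lowering, and the predictions helper stay exactly as in the example above.
backend = iree.IreeNpcompBackend()
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
jit_module = backend.load(compiled)
predictions(test_module.forward, jit_module.forward, img, labels)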
@@ -18,6 +18,10 @@ namespace IREEBackend {
 /// Registers all IREEBackend passes.
 void registerIREEBackendPasses();
 
+/// Create a pipeline that runs all passes needed to lower the npcomp backend
+/// contract to IREE's frontend contract.
+void createNpcompBackendToIreeFrontendPipeline(OpPassManager &pm);
+
 std::unique_ptr<OperationPass<ModuleOp>> createLowerLinkagePass();
 
 } // namespace IREEBackend
@@ -16,10 +16,20 @@ using namespace mlir::NPCOMP;
 using namespace mlir::NPCOMP::IREEBackend;
 
 namespace {
-#define GEN_PASS_REGISTRATION
-#include "npcomp/Backend/IREE/Passes.h.inc"
-} // end namespace
+// This pass lowers the public ABI of the module to the primitives exposed by
+// the refbackrt dialect.
+class LowerLinkagePass : public LowerLinkageBase<LowerLinkagePass> {
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    for (auto func : module.getOps<FuncOp>()) {
+      if (func.getVisibility() == SymbolTable::Visibility::Public)
+        func->setAttr("iree.module.export", UnitAttr::get(&getContext()));
+    }
+  }
+};
+} // namespace
 
-void mlir::NPCOMP::IREEBackend::registerIREEBackendPasses() {
-  ::registerPasses();
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::NPCOMP::IREEBackend::createLowerLinkagePass() {
+  return std::make_unique<LowerLinkagePass>();
 }
@@ -16,20 +16,21 @@ using namespace mlir::NPCOMP;
 using namespace mlir::NPCOMP::IREEBackend;
 
 namespace {
-// This pass lowers the public ABI of the module to the primitives exposed by
-// the refbackrt dialect.
-class LowerLinkagePass : public LowerLinkageBase<LowerLinkagePass> {
-  void runOnOperation() override {
-    ModuleOp module = getOperation();
-    for (auto func : module.getOps<FuncOp>()) {
-      if (func.getVisibility() == SymbolTable::Visibility::Public)
-        func->setAttr("iree.module.export", UnitAttr::get(&getContext()));
-    }
-  }
-};
-} // namespace
+#define GEN_PASS_REGISTRATION
+#include "npcomp/Backend/IREE/Passes.h.inc"
+} // end namespace
 
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::NPCOMP::IREEBackend::createLowerLinkagePass() {
-  return std::make_unique<LowerLinkagePass>();
+void mlir::NPCOMP::IREEBackend::createNpcompBackendToIreeFrontendPipeline(
+    OpPassManager &pm) {
+  pm.addPass(createLowerLinkagePass());
 }
+
+void mlir::NPCOMP::IREEBackend::registerIREEBackendPasses() {
+  ::registerPasses();
+
+  mlir::PassPipelineRegistration<>(
+      "npcomp-backend-to-iree-frontend-pipeline",
+      "Pipeline lowering the npcomp backend contract IR to IREE's frontend "
+      "contract.",
+      mlir::NPCOMP::IREEBackend::createNpcompBackendToIreeFrontendPipeline);
+}
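With the pipeline registered under a name, the whole backend-contract-to-IREE lowering becomes reachable from the MLIR Python bindings in a single call; a hedged sketch, where the binding import path and the module handle are assumptions:

# Hedged sketch: run the newly registered pipeline over an MLIR module from
# Python; this mirrors what the IREE backend wrapper below switches to.
from mlir.passmanager import PassManager  # assumed binding path

with module.context:  # `module` is a hypothetical mlir.ir.Module
    pm = PassManager.parse("npcomp-backend-to-iree-frontend-pipeline")
    pm.run(module)  # afterwards the module should meet IREE's frontend contract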
@@ -19,9 +19,6 @@ __all__ = [
     "IreeNpcompBackend",
 ]
 
-PREPARE_FOR_IREE_PASSES = (
-    "npcomp-iree-backend-lower-linkage",
-)
 
 class IreeModuleInvoker:
     """Wrapper around a native IREE module for calling functions."""
@@ -88,7 +85,7 @@ class IreeNpcompBackend(NpcompBackend):
         if self._debug:
             logging.debug("IR passed to IREE compiler backend:\n{}",
                           imported_module)
-        pipeline_str = ",".join(PREPARE_FOR_IREE_PASSES)
+        pipeline_str = "npcomp-backend-to-iree-frontend-pipeline"
         if self._debug:
             logging.debug("Running Prepare For IREE pipeline '{}'", pipeline_str)
         pm = PassManager.parse(pipeline_str)