# torch-mlir/examples/ltc_backend_mnist.py

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""
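Example training run on the example Torch-MLIR Lazy Tensor Core (LTC) backend; the TorchScript LTC backend and eager CPU execution can be selected instead for comparison.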
"""
import argparse
import sys
import torch
import torch._lazy
import torch.nn.functional as F
def main(device='lazy', *, num_epochs=3):
    """
    Load model to specified device. Ensure that any backends have been initialized by this point.

    Trains a small single-layer model (Linear(5, 10) followed by ReLU) with SGD
    on one fixed sample for ``num_epochs`` steps, then prints and returns the
    per-step losses.

    :param device: name of device to load tensors to (e.g. "cpu" or "lazy")
    :param num_epochs: number of optimizer steps to run; defaults to 3, the
        value that used to be hard-coded
    :return: tuple ``(model, losses)`` where ``losses`` holds one loss tensor
        per epoch
    """
    # Fixed seed so runs on different devices/backends start from identical weights.
    torch.manual_seed(0)

    inputs = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float32, device=device)
    assert inputs.device.type == device

    targets = torch.tensor([3], dtype=torch.int64, device=device)
    assert targets.device.type == device

    print("Initialized data")

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(5, 10)

        def forward(self, x):
            out = self.fc1(x)
            out = F.relu(out)
            return out

    model = Model().to(device)
    model.train()
    assert all(p.device.type == device for p in model.parameters())

    print("Initialized model")

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    losses = []

    for _ in range(num_epochs):
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        losses.append(loss)

        optimizer.step()

        if device == "lazy":
            # Mark the end of this epoch's step for the lazy tensor backend.
            print("Calling Mark Step")
            torch._lazy.mark_step()

    # Get debug information from LTC.
    # The backend module is fetched from sys.modules instead of relying on a
    # module-level ``lazy_backend`` global. That global only exists when this
    # file is run as a script, so calling main() from another module after
    # importing the backend used to raise NameError.
    lazy_backend = sys.modules.get("torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND")
    if lazy_backend is not None:
        computation = lazy_backend.get_latest_computation()
        if computation:
            print(computation.debug_string())

    print(losses)

    return model, losses
if __name__ == "__main__":
    # Command-line entry point: choose an execution target, bring up the
    # matching LTC backend when one is needed, then run the training example.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-d",
        "--device",
        type=str.upper,  # normalise the user's input to upper case before `choices` is checked
        choices=["CPU", "TS", "MLIR_EXAMPLE"],
        default="MLIR_EXAMPLE",
        help="The device type",
    )
    selected = cli.parse_args().device

    if selected not in ("TS", "MLIR_EXAMPLE"):
        # Eager execution: the torch device name is simply the lower-cased choice.
        device = selected.lower()
    else:
        if selected == "TS":
            import torch._lazy.ts_backend
            torch._lazy.ts_backend.init()
        else:
            # Kept under the module-level name `lazy_backend` in case other
            # code in this file refers to that global.
            import torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND as lazy_backend
            lazy_backend._initialize()
        # Both backend choices run on the generic "lazy" torch device.
        device = "lazy"
        print("Initialized backend")

    main(device)