torch-mlir/examples/ltc_backend_mnist.py

106 lines
2.7 KiB
Python

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""
Example of training a small model with the example Torch-MLIR lazy tensor core (LTC) backend.
"""
import argparse
import sys
import torch
import torch._lazy
import torch.nn.functional as F
def main(device='lazy'):
    """
    Train a tiny one-layer model for a few SGD steps on the given device.

    Any lazy-tensor backend must already be initialized before this is
    called; this function only places tensors/parameters on ``device``.

    :param device: device to load tensors and the model onto, e.g. ``"lazy"``
        or ``"cpu"``. Anything accepted by ``torch.device`` works, including
        indexed devices such as ``"cpu:0"`` and ``torch.device`` objects.
    :return: tuple ``(model, losses)`` -- the trained model and a list holding
        the loss tensor of each epoch.
    """
    # Compare device *types*: a tensor created on "cpu:0" (or "cuda:0")
    # reports device.type == "cpu" ("cuda"), never the full device string.
    device_type = torch.device(device).type

    torch.manual_seed(0)

    inputs = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float32, device=device)
    assert inputs.device.type == device_type

    targets = torch.tensor([3], dtype=torch.int64, device=device)
    assert targets.device.type == device_type

    print("Initialized data")

    class Model(torch.nn.Module):
        """Single linear layer (5 -> 10 features) followed by ReLU."""

        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(5, 10)

        def forward(self, x):
            out = self.fc1(x)
            out = F.relu(out)
            return out

    model = Model().to(device)
    model.train()
    assert all(p.device.type == device_type for p in model.parameters())

    print("Initialized model")

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    num_epochs = 3
    losses = []

    for _ in range(num_epochs):
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        losses.append(loss)

        optimizer.step()

        if device_type == "lazy":
            # Cut the lazily-traced graph here so each epoch is executed.
            print("Calling Mark Step")
            torch._lazy.mark_step()

    # Get debug information from LTC.
    # Look up the example backend by its full import name instead of relying
    # on a global alias that only exists when this file is run as a script;
    # None simply means that backend was never imported.
    backend = sys.modules.get("ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND")
    if backend is not None:
        computation = backend.get_latest_computation()
        # NOTE(review): presumably falsy when nothing was compiled yet — confirm.
        if computation:
            print(computation.debug_string())

    print(losses)

    return model, losses
if __name__ == "__main__":
    # Command-line entry point: choose a device, initialize the lazy-tensor
    # backend it needs (if any), then run the training example on it.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-d",
        "--device",
        type=str.upper,  # upper-cased before the choices check
        choices=["CPU", "TS", "MLIR_EXAMPLE"],
        default="MLIR_EXAMPLE",
        help="The device type",
    )
    requested = cli.parse_args().device

    if requested == "CPU":
        # Eager execution: the device name is just the lower-cased choice.
        device = requested.lower()
    else:
        if requested == "TS":
            import torch._lazy.ts_backend
            torch._lazy.ts_backend.init()
        elif requested == "MLIR_EXAMPLE":
            # Imported at module scope (not inside a helper function) so the
            # alias `ltc_backend` is bound as a module-level global.
            import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
            ltc_backend._initialize()
        # Both lazy backends place their tensors on the "lazy" device.
        device = "lazy"
        print("Initialized backend")

    main(device)