2022-04-15 00:53:00 +08:00
|
|
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
"""
|
|
|
|
Example use of the example Torch MLIR LTC backend.
|
|
|
|
"""
|
|
|
|
import argparse
|
2022-06-10 03:56:01 +08:00
|
|
|
import sys
|
2022-04-15 00:53:00 +08:00
|
|
|
|
2022-06-10 03:56:01 +08:00
|
|
|
import torch
|
|
|
|
import torch._lazy
|
2022-04-15 00:53:00 +08:00
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
2024-04-28 05:16:31 +08:00
|
|
|
def main(device="lazy"):
    """
    Train a tiny one-layer model for a few epochs on the specified device.

    Ensure that any backends have been initialized by this point.

    :param device: name of device to load tensors to (e.g. "lazy" or "cpu")
    :return: tuple ``(model, losses)`` with the trained module and the list of
        per-epoch loss tensors
    """
    # Fixed seed so that parameter initialization is reproducible across runs.
    torch.manual_seed(0)

    inputs = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float32, device=device)
    assert inputs.device.type == device

    targets = torch.tensor([3], dtype=torch.int64, device=device)
    assert targets.device.type == device

    print("Initialized data")

    class Model(torch.nn.Module):
        """Single fully connected layer (5 -> 10) followed by a ReLU."""

        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(5, 10)

        def forward(self, x):
            out = self.fc1(x)
            out = F.relu(out)
            return out

    model = Model().to(device)
    model.train()
    assert all(p.device.type == device for p in model.parameters())

    print("Initialized model")

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    num_epochs = 3
    losses = []
    for _ in range(num_epochs):
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        losses.append(loss)

        optimizer.step()

        if device == "lazy":
            print("Calling Mark Step")
            torch._lazy.mark_step()

    # Get debug information from LTC
    # Look the backend module up in sys.modules rather than relying on a
    # module-level `lazy_backend` name: that name is only bound when this file
    # runs as a script, so an importing caller would otherwise hit a NameError.
    lazy_backend = sys.modules.get("torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND")
    if lazy_backend is not None:
        computation = lazy_backend.get_latest_computation()
        if computation:
            print(computation.debug_string())

    print(losses)

    return model, losses
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: choose a backend, initialize it, run the example.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-d",
        "--device",
        type=str.upper,
        choices=["CPU", "TS", "MLIR_EXAMPLE"],
        default="MLIR_EXAMPLE",
        help="The device type",
    )
    args = cli.parse_args()
    requested = args.device

    # Both lazy backends expose tensors through the "lazy" device; they only
    # differ in which backend module has to be imported and initialized first.
    uses_lazy_backend = True
    if requested == "TS":
        import torch._lazy.ts_backend

        torch._lazy.ts_backend.init()
    elif requested == "MLIR_EXAMPLE":
        import torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND as lazy_backend

        lazy_backend._initialize()
    else:
        uses_lazy_backend = False

    if uses_lazy_backend:
        device = "lazy"
        print("Initialized backend")
    else:
        # Eager devices are passed to torch by their lower-case name.
        device = requested.lower()

    main(device)
|