# Mirror of https://github.com/llvm/torch-mlir (53 lines, 1.4 KiB, Python).
# -*- Python -*-
# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
import npcomp.frontends.pytorch as torch_mlir
|
|
import copy
|
|
|
|
|
|
def compare(a, b, test, *, tol=1e-5):
    """Print whether two tensors agree elementwise within an absolute tolerance.

    Both tensors are detached and moved to the CPU before comparison, so ``a``
    and ``b`` may live on different devices (e.g. a reference result and a
    result produced on the MLIR device).

    Args:
        a: Reference tensor.
        b: Tensor checked against ``a``.
        test: Label included in the printed messages.
        tol: Maximum allowed absolute elementwise difference (keyword-only).
            Defaults to 1e-5, the value that used to be hard-coded here.

    Returns:
        True if the check passed (a "PASS!" line was printed), False otherwise.
        A shape mismatch or a NaN difference counts as a failure.
    """
    print("Computing:" + test)
    # Detach so the error computation does not extend any autograd graph.
    a_cpu = a.detach().to('cpu')
    b_cpu = b.detach().to('cpu')
    if a_cpu.shape != b_cpu.shape:
        # Without this check, subtraction would broadcast mismatched shapes
        # (e.g. (3,) vs (1,)) and could report a spurious PASS.
        passed = False
    elif a_cpu.numel() == 0:
        # .max() raises on an empty tensor; two equally-shaped empty tensors agree.
        passed = True
    else:
        err = (a_cpu - b_cpu).abs().max()
        # A NaN error compares False here, so NaNs are reported as FAILED.
        passed = bool(err <= tol)
    if passed:
        print("PASS! " + test + " check")
    else:
        print("FAILED " + test + " check")
    return passed
|
|
|
|
|
|
def compare_eq(a, b, test):
    """Print whether ``a == b`` holds for the check labelled ``test``.

    Args:
        a: First value.
        b: Second value. The result of ``a == b`` is used in a boolean
            context, so it must support truth testing (e.g. it must not be a
            multi-element tensor).
        test: Label included in the printed messages.

    Returns:
        True if the check passed (a "PASS!" line was printed), False otherwise.
    """
    print("Computing:" + test)
    if a == b:
        print("PASS! " + test + " check")
        return True
    print("FAILED " + test + " check")
    return False
|
|
|
|
|
|
def check_fwd(model, tensor):
    """Run ``model`` as-is and on the MLIR device, then compare the outputs.

    The reference output comes from calling ``model`` directly on ``tensor``.
    A deep copy of ``model`` is then moved to the device returned by
    ``torch_mlir.mlir_device()`` and run on a clone of ``tensor`` moved to
    that same device. The two outputs are checked with ``compare`` under the
    label "fwd".

    Args:
        model: Callable object supporting ``copy.deepcopy`` and ``.to(device)``.
        tensor: Model input; must support ``.clone()`` and ``.to(device)``.

    Returns:
        A ``(device_model, device_result, result)`` tuple: the device copy of
        the model, its output, and the reference output.
    """
    mlir_dev = torch_mlir.mlir_device()
    reference_out = model(tensor)

    # Work on a deep copy so that moving it to the device does not touch the
    # caller's ``model`` object.
    model_on_dev = copy.deepcopy(model).to(mlir_dev)
    out_on_dev = model_on_dev(tensor.clone().to(mlir_dev))

    compare(reference_out, out_on_dev, "fwd")
    return model_on_dev, out_on_dev, reference_out
|
|
|
|
|
|
def check_ref(model, tensor):
    """Alias for ``check_fwd``; passes both arguments through unchanged.

    Returns whatever ``check_fwd(model, tensor)`` returns: the
    ``(device_model, device_result, result)`` tuple.
    """
    fwd_path = check_fwd(model, tensor)
    return fwd_path
|
|
|
|
|
|
def check_back(fwd_path, target, lossmodel):
    """Backpropagate a loss through both the reference and device outputs.

    ``lossmodel`` is applied to the reference output against ``target``, and
    to the device output against a clone of ``target`` moved to the MLIR
    device. ``backward()`` is called on each loss, and the two loss values
    are then checked with ``compare`` under the label "back".

    Only the loss values are compared here, not the resulting gradients.

    Args:
        fwd_path: The ``(device_model, device_result, result)`` tuple returned
            by ``check_fwd`` / ``check_ref``.
        target: Loss target for the reference side.
        lossmodel: Callable ``lossmodel(prediction, target)`` returning a loss
            on which ``backward()`` can be called.

    Returns:
        A ``(device_model, device_result)`` tuple taken from ``fwd_path``.
    """
    mlir_dev = torch_mlir.mlir_device()
    model_on_dev, out_on_dev, reference_out = fwd_path
    target_on_dev = target.clone().to(mlir_dev)

    # Reference side: loss and backward pass through the original graph.
    loss_ref = lossmodel(reference_out, target)
    loss_ref.backward()

    # Device side: the same loss computed from the device output.
    loss_dev = lossmodel(out_on_dev, target_on_dev)
    loss_dev.backward()

    compare(loss_ref, loss_dev, "back")
    return model_on_dev, out_on_dev
|