Added e2e LTC tests (#916)

* Added e2e LTC Torch MLIR tests

* Fix seed for reproducability

* Check if computation is None before getting debug string

* Updated unit tests, and added numeric tests

* Print name of the model layer that fails numeric validation

* Run LTC e2e test with CI/CD

* Set seed in main function, instead of beginning of execution

* Add comment to specify number of digits of precision

* Fixed typo

* Remove tests for LTC example models

* Added LTC option to torchscript e2e

* Implement compile and run for LTC e2e test

* xfail all tests that use ops that aren't currently supported
pull/1125/head
Henry Tu 2022-06-09 15:56:01 -04:00 committed by Henry Tu
parent 8312fa535b
commit dfcc26556a
12 changed files with 498 additions and 97 deletions

View File

@ -128,3 +128,10 @@ jobs:
cd $GITHUB_WORKSPACE
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
python -m e2e_testing.torchscript.main --config=tosa -v
- name: Lazy Tensor Core - TorchScript end-to-end tests
if: matrix.llvmtype == 'binary'
run: |
cd $GITHUB_WORKSPACE
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
python -m e2e_testing.torchscript.main --config=lazy_tensor_core -v

View File

@ -15,20 +15,20 @@ from torch_mlir_e2e_test.torchscript.serialization import deserialize_all_tests_
# Available test configs.
from torch_mlir_e2e_test.torchscript.configs import (
LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig
LazyTensorCoreTestConfig, LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig
)
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET
from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET
# Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()
def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode']
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode', 'lazy_tensor_core']
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('-c', '--config',
choices=config_choices,
@ -40,6 +40,7 @@ Meaning of options:
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
"eager_mode": run through torch-mlir's eager mode frontend, using RefBackend for execution.
"lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph.
''')
parser.add_argument('-f', '--filter', default='.*', help='''
Regular expression specifying which tests to include in this run.
@ -86,6 +87,9 @@ def main():
elif args.config == 'eager_mode':
config = EagerModeTestConfig()
xfail_set = EAGER_MODE_XFAIL_SET
elif args.config == 'lazy_tensor_core':
config = LazyTensorCoreTestConfig()
xfail_set = LTC_XFAIL_SET
# Find the selected tests, and emit a diagnostic if none are found.
tests = [

View File

@ -179,3 +179,312 @@ TOSA_PASS_SET = {
"ArgmaxModule_with_dim",
"_LogSoftmaxModuleStable_basic",
}
LTC_XFAIL_SET = {
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"AddIntModule_basic",
"AllBoolFalseModule_basic",
"AllBoolTrueModule_basic",
"AnyBoolFalseModule_basic",
"AnyBoolTrueModule_basic",
"ArangeDtypeFloatModule_basic",
"ArangeDtypeIntModule_basic",
"ArangeFalsePinMemoryModule_basic",
"ArangeFloatModule_basic",
"ArangeIntModule_basic",
"ArangeNegativeStartFloatModule_basic",
"ArangeNegativeStartIntModule_basic",
"ArangeStartFloatModule_basic",
"ArangeStartIntModule_basic",
"ArangeStartNegativeStepFloatModule_basic",
"ArangeStartNegativeStepIntModule_basic",
"ArangeStartStepFloatModule_basic",
"ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
"AvgPool2dCeilModeTrueModule_basic",
"AvgPool2dDivisorOverrideModule_basic",
"AvgPool2dFloatModule_basic",
"AvgPool2dIntModule_basic",
"AvgPool2dStaticModule_basic",
"BernoulliFloatModule_basic",
"BernoulliModule_basic",
"BernoulliOnesModule_basic",
"BernoulliTensorModule_basic",
"BernoulliZerosModule_basic",
"BincountMinlengthModule_basic",
"BincountModule_basic",
"BincountStaticSizeModule_basic",
"BoolFloatConstantModule_basic",
"BoolFloatFalseModule_basic",
"BoolFloatTrueModule_basic",
"BoolIntConstantModule_basic",
"BoolIntFalseModule_basic",
"BoolIntTrueModule_basic",
"CeilFloatModule_basic",
"DivFloatModule_basic",
"DropoutTrainModule_basic",
"ElementwiseAtenLogicalOrOpBrodcastModule_basic",
"ElementwiseAtenLogicalOrOpDiffArgs1Module_basic",
"ElementwiseAtenLogicalOrOpDiffArgs2Module_basic",
"ElementwiseAtenLogicalOrOpDiffArgs3Module_basic",
"ElementwiseAtenLogicalOrOpModule_basic",
"ElementwiseAtenLogicalOrOpNegativeModule_basic",
"ElementwiseAtenLogicalOrOpRandomFloatModule_basic",
"ElementwiseAtenLogicalOrOpRandomModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampModule_basic",
"ElementwiseWhereScalarModule_basic",
"ElementwiseWhereScalarOtherModule_basic",
"ElementwiseWhereScalarSelfModule_basic",
"ElementwiseWhereSelfModule_basic",
"EmptyLikeMemoryFormatModule_basic",
"EmptyLikeModule_defaultDtype",
"EmptyLikeModule_falsePinMemory",
"EmptyLikeModule_float",
"EmptyLikeModule_int",
"EmptyModule_contiguous",
"EmptyModule_defaultDtype",
"EmptyModule_falsePinMemory",
"EmptyModule_float",
"EmptyModule_int",
"EqIntModule_basic",
"Fill_TensorFloat64WithFloat32_basic",
"Fill_TensorFloat64WithFloat64_basic",
"Fill_TensorFloat64WithInt64_basic",
"FullLikeModuleDefaultDtype_basic",
"FullLikeModuleFalsePinMemory_basic",
"FullLikeModuleFloat2D_basic",
"FullLikeModuleFloat3DStatic_basic",
"FullLikeModuleFloat3D_basic",
"FullLikeModuleInt2DStatic_basic",
"FullLikeModuleInt2D_basic",
"FullLikeModuleInt3D_basic",
"FullModuleDefaultDtype_basic",
"FullModuleFalsePinMemory_basic",
"FullModuleFloat2D_basic",
"FullModuleFloat3D_basic",
"FullModuleInt2D_basic",
"FullModuleInt3D_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GtFloatIntModule_basic",
"GtIntModule_basic",
"HBC_basic",
"HardTanhIntModule_basic",
"HardTanhModule_basic",
"HardswishModule_basic",
"HardswishRandomModule_basic",
"IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DFloatNonAccumulateModule_basic",
"IndexPut1DIntAccumulateModule_basic",
"IndexPut1DIntNonAccumulateModule_basic",
"IndexPut2DFloatAccumulateModule_basic",
"IndexPut2DFloatNonAccumulateModule_basic",
"IndexPut2DIntAccumulateModule_basic",
"IndexPut2DIntNonAccumulateModule_basic",
"IndexPut3DFloatAccumulateModule_basic",
"IndexPut3DFloatNonAccumulateModule_basic",
"IndexPut3DIntAccumulateModule_basic",
"IndexPut3DIntNonAccumulateModule_basic",
"IndexPutHackedTwin1DFloatAccumulateModule_basic",
"IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin1DIntAccumulateModule_basic",
"IndexPutHackedTwin1DIntNonAccumulateModule_basic",
"IndexPutHackedTwin2DFloatAccumulateModule_basic",
"IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin2DIntAccumulateModule_basic",
"IndexPutHackedTwin2DIntNonAccumulateModule_basic",
"IndexPutHackedTwin3DFloatAccumulateModule_basic",
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin3DIntAccumulateModule_basic",
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
"IndexPutImpl1DFloatAccumulateModule_basic",
"IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntAccumulateModule_basic",
"IndexPutImpl1DIntNonAccumulateModule_basic",
"IndexPutImpl2DFloatAccumulateModule_basic",
"IndexPutImpl2DFloatNonAccumulateModule_basic",
"IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexSelectDynamicIndexSizeModule_basic",
"IndexSelectDynamicInputSizeModule_basic",
"IndexSelectDynamicModulebasic",
"IndexSelectSingleIdxModule_basic",
"IndexSelectTwoIdxModule_basic",
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"IndexTensorModule_basic",
"MaskedFillScalarDefaultModule_basic",
"MaskedFillScalarFloatValueModule_basic",
"MaskedFillScalarIntValueModule_basic",
"Matmul_dot",
"Matmul_matvec",
"Matmul_vecmat",
"MaxPool2dCeilModeTrueModule_basic",
"MaxPool2dModule_basic",
"MaxPool2dStaticModule_basic",
"MaxPool2dWith3dInputModule_basic",
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
"MaxPool2dWithIndicesAllOnesModule_basic",
"MaxPool2dWithIndicesBackwardDynamic3DModule_basic",
"MaxPool2dWithIndicesBackwardDynamic4DModule_basic",
"MaxPool2dWithIndicesBackwardStatic3DModule_basic",
"MaxPool2dWithIndicesBackwardStatic4DModule_basic",
"MaxPool2dWithIndicesCeilModeTrueModule_basic",
"MaxPool2dWithIndicesFullSizeKernelModule_basic",
"MaxPool2dWithIndicesModule_basic",
"MaxPool2dWithIndicesNonDefaultDilationModule_basic",
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
"MaxPool2dWithIndicesNonDefaultParamsModule_basic",
"MaxPool2dWithIndicesNonDefaultStrideModule_basic",
"MaxPool2dWithIndicesStaticModule_basic",
"MaxPool2dWithIndicesWith3dInputModule_basic",
"MeanDimAllReduceKeepdimModule_basic",
"MeanDimAllReduceModule_basic",
"MeanDimDtypeModule_basic",
"MeanDimKeepdimModule_basic",
"MeanDimModule_basic",
"MeanDimNegativeModule_basic",
"MeanDtypeModule_basic",
"MeanDynamicSizesModule_basic",
"MeanModule_basic",
"MobilenetV3Module_basic",
"MulIntModule_basic",
"NativeBatchNorm1DModule_basic",
"NativeBatchNorm2DModule_basic",
"NativeBatchNorm3DModule_basic",
"NativeBatchNormNoneWeightModule_basic",
"NativeLayerNormDynamicModule_basic",
"NativeLayerNormModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
"NewEmptyModuleDefaultDtype_basic",
"NewEmptyModuleFalsePinMemory_basic",
"NewEmptyModuleFloat2D_basic",
"NewEmptyModuleFloat3D_basic",
"NewEmptyModuleInt2D_basic",
"NewEmptyModuleInt3D_basic",
"NewEmptyModuleLayoutIntDtype_basic",
"NewEmptyModuleNonDefaultFloatDtype_basic",
"NewEmptyModuleNonDefaultIntDtype_basic",
"NewOnesModuleDefaultDtype_basic",
"NewOnesModuleFalsePinMemory_basic",
"NewOnesModuleFloat2D_basic",
"NewOnesModuleFloat3D_basic",
"NewOnesModuleInt2D_basic",
"NewOnesModuleInt3D_basic",
"NewZerosModuleDefaultDtype_basic",
"NewZerosModuleFalsePinMemory_basic",
"NewZerosModuleFloat2D_basic",
"NewZerosModuleFloat3D_basic",
"NewZerosModuleInt2D_basic",
"NewZerosModuleInt3D_basic",
"NllLossModuleBackward1DMeanWeight_basic",
"NllLossModuleBackward1DMean_basic",
"NllLossModuleBackward1DSumWeight_basic",
"NllLossModuleBackward1DSum_basic",
"NllLossModuleBackward1DWeight_basic",
"NllLossModuleBackward1D_basic",
"NllLossModuleBackwardMeanWeight_basic",
"NllLossModuleBackwardMean_basic",
"NllLossModuleBackwardSumWeight_basic",
"NllLossModuleBackwardSum_basic",
"NllLossModuleBackwardWeight_basic",
"NllLossModuleBackward_basic",
"NllLossModuleBackward_ignore_index",
"NllLossModule_1D_basic",
"NllLossModule_basic",
"NllLossModule_ignore_index_out_of_bounds_basic",
"NllLossModule_mean_basic",
"NllLossModule_sum_basic",
"NumelModule_basic",
"NumelZeroRankModule_basic",
"OnesLikeModule_defaultDtype",
"OnesLikeModule_falsePinMemory",
"OnesLikeModule_float",
"OnesLikeModule_int",
"OnesModuleDefaultDtype_basic",
"OnesModuleFalsePinMemory_basic",
"OnesModuleFloat_basic",
"OnesModuleInt_basic",
"QuantizedMLP_basic",
"RandLikeDtypeModule_basic",
"RandLikeModule_basic",
"ReduceMaxKeepDimReturnBoth_basic",
"ReduceMaxNegativeDim_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic",
"ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic",
"SelectIntModule_basic",
"SliceEndSleStartModule_basic",
"SliceNegIdxModule_basic",
"SliceOutOfLowerBoundEndIndexModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceOutOfUpperBoundIndexModule_basic",
"SliceSingleIdxModule_basic",
"SliceSizeTwoStepModule_basic",
"SliceStartEqEndModule_basic",
"SliceWholeTensorModule_basic",
"SqrtIntConstantModule_basic",
"SqrtIntModule_basic",
"StdBiasedModule_basic",
"StdUnbiasedModule_basic",
"SubFloatModule_basic",
"SubIntModule_basic",
"TModuleRank0_basic",
"TModuleRank1_basic",
"TableBatchEmbeddingModule_basic",
"TensorToBoolZeroRank_basic",
"TensorToBool_basic",
"TensorToFloatZeroRank_basic",
"TensorToFloat_basic",
"TensorToIntZeroRank_basic",
"TensorToInt_basic",
"TensorsConcatModule_basic",
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
"TestMultipleTensorReturn_basic",
"Threshold1dFloatModule_basic",
"Threshold1dIntI32Module_basic",
"Threshold1dIntModule_basic",
"Threshold2dFloatModule_basic",
"Threshold2dIntModule_basic",
"Threshold3dFloatModule_basic",
"Threshold3dIntModule_basic",
"ThresholdBackward1dFloatModule_basic",
"ThresholdBackward1dIntModule_basic",
"ThresholdBackward1dMixedModule_basic",
"ThresholdBackward2dFloatModule_basic",
"ThresholdBackward2dIntModule_basic",
"ThresholdBackward2dMixedModule_basic",
"ThresholdBackward3dFloatModule_basic",
"ThresholdBackward3dIntModule_basic",
"ThresholdBackward3dMixedModule_basic",
"TorchPrimLoopForLikeModule_basic",
"TorchPrimLoopWhileLikeModule_basic",
"UniformModule_basic",
"UniformStaticModule_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"VarBiasedModule_basic",
"VarUnbiasedModule_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ZeroFloat32Module_basic",
"ZeroInt32Module_basic",
"ZeroInt64Module_basic",
"ZerosLikeModule_defaultDtype",
"ZerosLikeModule_falsePinMemory",
"ZerosLikeModule_float",
"ZerosLikeModule_int",
"ZerosModuleDefaultDtype_basic",
"ZerosModuleFalsePinMemory_basic",
"ZerosModuleFloat2D_basic",
"ZerosModuleFloat3D_basic",
"ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic",
}

View File

@ -12,8 +12,8 @@
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/shape.h>
#include <torch_mlir/csrc/base_lazy_backend/generated/LazyNativeFunctions.h>
#include <torch_mlir/csrc/base_lazy_backend/backend_impl.h>
#include <torch_mlir/csrc/base_lazy_backend/generated/LazyNativeFunctions.h>
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
#include <torch_mlir/csrc/utils/debug.h>
#include <torch_mlir/csrc/utils/exception.h>
@ -60,10 +60,13 @@ public:
// Vendor backend specific lowering can be exec here before returning.
for (const auto &instance : instances) {
std::cout << "Instance received at Compile: \n"
<< GetComputationBackendText(instance) << std::endl;
// Store computation instance for external access after compilation.
GetLatestComputation() = instance;
}
std::cout << "Received " << instances.size()
<< " computation instances at Compile!" << std::endl;
return instances;
}
@ -133,9 +136,13 @@ public:
* */
std::string
GetComputationBackendText(const ComputationPtr computation) const override {
auto mlir_computation =
static_cast<TorchMlirComputation *>(computation.get());
return mlir_computation->to_string();
// Store computation instance for external access after compilation.
// We do this in GetComputationBackendText since there may be instances
// where a ComputationPtr does not pass through Compile (e.g. when using
// DumpUtil::ToBackend.)
GetLatestComputation() = computation;
return computation->to_string();
}
private:
@ -154,5 +161,11 @@ void InitExampleMlirBackend() {
g_registrar.reset(new BackendRegistrar(GetExampleMlirBackendImpl()));
}
ComputationPtr &GetLatestComputation() {
// Store the computation from the most recent compile.
static ComputationPtr computation;
return computation;
}
} // namespace lazy
} // namespace torch

View File

@ -23,5 +23,7 @@ torch::lazy::BackendImplInterface *GetExampleMlirBackendImpl();
void InitExampleMlirBackend();
ComputationPtr &GetLatestComputation();
} // namespace lazy
} // namespace torch

View File

@ -10,6 +10,8 @@
#include "torch/csrc/jit/python/pybind.h"
#include "torch/csrc/lazy/backend/backend_interface.h"
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
#include <exception>
#include <iostream>
#include <string>
@ -61,7 +63,16 @@ void Shutdown() {
} // anonymous namespace
PYBIND11_MODULE(_EXAMPLE_MLIR_BACKEND, m) {
py::class_<torch::lazy::TorchMlirComputation>(m, "TorchMlirComputation")
.def("to_string", &torch::lazy::TorchMlirComputation::to_string)
.def("debug_string", &torch::lazy::TorchMlirComputation::debug_string);
m.doc() = ("pybind11 for example MLIR LTC backend.");
m.def("get_latest_computation", []() {
auto computation = static_cast<torch::lazy::TorchMlirComputation *>(
torch::lazy::GetLatestComputation().get());
return py::cast(computation);
});
m.def("_initialize", []() {
NoGilSection gil;
Initialize();

View File

@ -14,13 +14,19 @@ Based on LTC code samples by ramiro050
"""
import argparse
import sys
from typing import List
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
import torch
import torch._C
import torch._lazy
import torch._lazy.ts_backend
from datasets import load_dataset
from datasets.dataset_dict import DatasetDict
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, \
BertConfig, BertTokenizer, AdamW, get_scheduler
from typing import List
def tokenize_dataset(dataset: DatasetDict) -> DatasetDict:
@ -42,8 +48,7 @@ def train(model: BertForSequenceClassification,
num_epochs: int,
num_training_steps: int,
train_dataloader: DataLoader,
device: torch.device,
do_mark_step: bool) -> List[torch.Tensor]:
device: torch.device) -> List[torch.Tensor]:
optimizer = AdamW(model.parameters(), lr=5e-5)
lr_scheduler = get_scheduler('linear', optimizer=optimizer,
num_warmup_steps=0,
@ -63,31 +68,21 @@ def train(model: BertForSequenceClassification,
lr_scheduler.step()
optimizer.zero_grad()
if do_mark_step and 'lazy' in str(model.device):
if 'lazy' in str(model.device):
print("Calling Mark Step")
torch._lazy.mark_step()
return losses
def main(device, lower_only, full_size):
if device in ("TS", "MLIR_EXAMPLE"):
import torch._lazy
def main(device='lazy', full_size=False):
"""
Load model to specified device. Ensure that any backends have been initialized by this point.
if device == "TS":
import torch._lazy.ts_backend
torch._lazy.ts_backend.init()
elif device == "MLIR_EXAMPLE":
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
ltc_backend._initialize()
device = "lazy"
print("Initialized backend")
else:
device = device.lower()
:param device: name of device to load tensors to
:param full_size: if true, use a full pretrained bert-base-cased model instead of a smaller variant
"""
torch.manual_seed(0)
tokenized_datasets = tokenize_dataset(load_dataset('imdb'))
small_train_dataset = tokenized_datasets['train'].shuffle(seed=42) \
@ -117,22 +112,20 @@ def main(device, lower_only, full_size):
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
losses = train(model, num_epochs,
num_training_steps, train_dataloader, device, not lower_only)
losses = train(model, num_epochs, num_training_steps, train_dataloader, device)
# Get debug information from LTC
if 'ltc_backend' in sys.modules:
computation = ltc_backend.get_latest_computation()
if computation:
print(computation.debug_string())
if lower_only:
print('\nJIT Graph:')
import torch._C
graph_str = torch._C._lazy._get_tensors_backend([losses[0]])
print(graph_str)
else:
# Execute computation
print('Loss: ', losses)
return model, losses
if __name__ == "__main__":
torch.manual_seed(0)
parser = argparse.ArgumentParser()
parser.add_argument(
"-d",
@ -142,13 +135,6 @@ if __name__ == "__main__":
default="MLIR_EXAMPLE",
help="The device type",
)
parser.add_argument(
"-l",
"--lower_only",
action='store_true',
default=False,
help="Only get backend printout -- do not execute computation",
)
parser.add_argument(
"-f",
"--full_size",
@ -157,4 +143,17 @@ if __name__ == "__main__":
help="Use full sized BERT model instead of one with smaller parameterization",
)
args = parser.parse_args()
main(args.device, args.lower_only, args.full_size)
if args.device in ("TS", "MLIR_EXAMPLE"):
if args.device == "TS":
torch._lazy.ts_backend.init()
elif args.device == "MLIR_EXAMPLE":
ltc_backend._initialize()
device = "lazy"
print("Initialized backend")
else:
device = args.device.lower()
main(device, args.full_size)

View File

@ -6,30 +6,22 @@
Example use of the example Torch MLIR LTC backend.
"""
import argparse
import sys
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
import torch
import torch._lazy
import torch._lazy.ts_backend
import torch.nn.functional as F
def main(device):
import torch
def main(device='lazy'):
"""
Load model to specified device. Ensure that any backends have been initialized by this point.
if device in ("TS", "MLIR_EXAMPLE"):
import torch._lazy
if device == "TS":
import torch._lazy.ts_backend
torch._lazy.ts_backend.init()
elif device == "MLIR_EXAMPLE":
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
ltc_backend._initialize()
device = "lazy"
print("Initialized backend")
else:
device = device.lower()
:param device: name of device to load tensors to
"""
torch.manual_seed(0)
inputs = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float32, device=device)
assert inputs.device.type == device
@ -57,24 +49,35 @@ def main(device):
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
num_epochs = 3
losses = []
for _ in range(num_epochs):
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
losses.append(loss)
optimizer.step()
if device == "lazy":
print("Calling Mark Step")
torch._lazy.mark_step()
print()
print(loss)
# Get debug information from LTC
if 'ltc_backend' in sys.modules:
computation = ltc_backend.get_latest_computation()
if computation:
print(computation.debug_string())
print(losses)
return model, losses
if __name__ == "__main__":
torch.manual_seed(0)
parser = argparse.ArgumentParser()
parser.add_argument(
"-d",
@ -85,4 +88,17 @@ if __name__ == "__main__":
help="The device type",
)
args = parser.parse_args()
main(args.device)
if args.device in ("TS", "MLIR_EXAMPLE"):
if args.device == "TS":
torch._lazy.ts_backend.init()
elif args.device == "MLIR_EXAMPLE":
ltc_backend._initialize()
device = "lazy"
print("Initialized backend")
else:
device = args.device.lower()
main(device)

View File

@ -323,24 +323,14 @@ std::shared_ptr<torch::jit::Graph> TorchMlirComputation::graph() const {
MlirOperation TorchMlirComputation::func_op() const { return func_op_; }
const std::string TorchMlirComputation::to_string() const {
// Since we use the C-MLIR API, we need to use a callback to print.
MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) {
// user_data is a void ptr to some data structure of our choice -- in this
// case, the string stream where we'll be accumulating the strings.
std::stringstream* ss_ptr = static_cast<std::stringstream*>(user_data);
*ss_ptr << std::string(part.data, part.length);
};
const std::string TorchMlirComputation::debug_string() const {
std::stringstream ss;
// JIT Graph
ss << "JIT Graph: \n" << graph_->toString() << "\n\n";
// MLIR
ss << "MLIR: \n";
mlirOperationPrint(func_op_, print_callback, &ss);
ss << "\n";
ss << "MLIR: \n" << to_string() << "\n";
// Input/Output Mapping
ss << "Input/Output Alias Mapping: \n";
@ -356,5 +346,18 @@ const std::string TorchMlirComputation::to_string() const {
return ss.str();
}
const std::string TorchMlirComputation::to_string() const {
// Since we use the C-MLIR API, we need to use a callback to print.
MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) {
// user_data is a void ptr to some data structure of our choice -- in this
// case, the string stream where we'll be accumulating the strings.
std::stringstream* ss_ptr = static_cast<std::stringstream*>(user_data);
*ss_ptr << std::string(part.data, part.length);
};
std::stringstream ss;
mlirOperationPrint(func_op_, print_callback, &ss);
return ss.str();
}
} // namespace lazy
} // namespace torch

View File

@ -135,6 +135,8 @@ public:
MlirOperation func_op() const;
const std::string debug_string() const;
const std::string to_string() const;
private:

View File

@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from .lazy_tensor_core import LazyTensorCoreTestConfig
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
from .native_torch import NativeTorchTestConfig
from .torchscript import TorchScriptTestConfig

View File

@ -0,0 +1,34 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
import torch
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
class LazyTensorCoreTestConfig(TestConfig):
"""TestConfig that runs torch.nn.Module thru the Lazy Tensor Core frontend for Torch MLIR"""
def __init__(self):
super().__init__()
ltc_backend._initialize()
def compile(self, program: torch.nn.Module) -> torch.nn.Module:
return program.to('lazy')
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
result: Trace = []
for item in trace:
# We need to move all the inputs to the lazy device before running in LTC.
lazy_inputs = [arg.to('lazy') for arg in item.inputs]
output = getattr(artifact, item.symbol)(*lazy_inputs)
result.append(
TraceItem(symbol=item.symbol,
inputs=item.inputs,
output=output.to('cpu')))
return result