mirror of https://github.com/llvm/torch-mlir
Added e2e LTC tests (#916)
* Added e2e LTC Torch MLIR tests * Fix seed for reproducability * Check if computation is None before getting debug string * Updated unit tests, and added numeric tests * Print name of the model layer that fails numeric validation * Run LTC e2e test with CI/CD * Set seed in main function, instead of beginning of execution * Add comment to specify number of digits of precision * Fixed typo * Remove tests for LTC example models * Added LTC option to torchscript e2e * Implement compile and run for LTC e2e test * xfail all tests that use ops that aren't currently supportedpull/1125/head
parent
8312fa535b
commit
dfcc26556a
|
@ -128,3 +128,10 @@ jobs:
|
|||
cd $GITHUB_WORKSPACE
|
||||
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
|
||||
python -m e2e_testing.torchscript.main --config=tosa -v
|
||||
|
||||
- name: Lazy Tensor Core - TorchScript end-to-end tests
|
||||
if: matrix.llvmtype == 'binary'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
|
||||
python -m e2e_testing.torchscript.main --config=lazy_tensor_core -v
|
||||
|
|
|
@ -15,20 +15,20 @@ from torch_mlir_e2e_test.torchscript.serialization import deserialize_all_tests_
|
|||
|
||||
# Available test configs.
|
||||
from torch_mlir_e2e_test.torchscript.configs import (
|
||||
LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig
|
||||
LazyTensorCoreTestConfig, LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig
|
||||
)
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
||||
|
||||
from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET
|
||||
from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET
|
||||
|
||||
# Import tests to register them in the global registry.
|
||||
from torch_mlir_e2e_test.test_suite import register_all_tests
|
||||
register_all_tests()
|
||||
|
||||
def _get_argparse():
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode']
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode', 'lazy_tensor_core']
|
||||
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
||||
parser.add_argument('-c', '--config',
|
||||
choices=config_choices,
|
||||
|
@ -40,6 +40,7 @@ Meaning of options:
|
|||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||
"eager_mode": run through torch-mlir's eager mode frontend, using RefBackend for execution.
|
||||
"lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph.
|
||||
''')
|
||||
parser.add_argument('-f', '--filter', default='.*', help='''
|
||||
Regular expression specifying which tests to include in this run.
|
||||
|
@ -86,6 +87,9 @@ def main():
|
|||
elif args.config == 'eager_mode':
|
||||
config = EagerModeTestConfig()
|
||||
xfail_set = EAGER_MODE_XFAIL_SET
|
||||
elif args.config == 'lazy_tensor_core':
|
||||
config = LazyTensorCoreTestConfig()
|
||||
xfail_set = LTC_XFAIL_SET
|
||||
|
||||
# Find the selected tests, and emit a diagnostic if none are found.
|
||||
tests = [
|
||||
|
|
|
@ -179,3 +179,312 @@ TOSA_PASS_SET = {
|
|||
"ArgmaxModule_with_dim",
|
||||
"_LogSoftmaxModuleStable_basic",
|
||||
}
|
||||
|
||||
LTC_XFAIL_SET = {
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
|
||||
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
||||
"AddIntModule_basic",
|
||||
"AllBoolFalseModule_basic",
|
||||
"AllBoolTrueModule_basic",
|
||||
"AnyBoolFalseModule_basic",
|
||||
"AnyBoolTrueModule_basic",
|
||||
"ArangeDtypeFloatModule_basic",
|
||||
"ArangeDtypeIntModule_basic",
|
||||
"ArangeFalsePinMemoryModule_basic",
|
||||
"ArangeFloatModule_basic",
|
||||
"ArangeIntModule_basic",
|
||||
"ArangeNegativeStartFloatModule_basic",
|
||||
"ArangeNegativeStartIntModule_basic",
|
||||
"ArangeStartFloatModule_basic",
|
||||
"ArangeStartIntModule_basic",
|
||||
"ArangeStartNegativeStepFloatModule_basic",
|
||||
"ArangeStartNegativeStepIntModule_basic",
|
||||
"ArangeStartStepFloatModule_basic",
|
||||
"ArangeStartStepIntModule_basic",
|
||||
"ArangeZeroElementOutputModule_basic",
|
||||
"AvgPool2dCeilModeTrueModule_basic",
|
||||
"AvgPool2dDivisorOverrideModule_basic",
|
||||
"AvgPool2dFloatModule_basic",
|
||||
"AvgPool2dIntModule_basic",
|
||||
"AvgPool2dStaticModule_basic",
|
||||
"BernoulliFloatModule_basic",
|
||||
"BernoulliModule_basic",
|
||||
"BernoulliOnesModule_basic",
|
||||
"BernoulliTensorModule_basic",
|
||||
"BernoulliZerosModule_basic",
|
||||
"BincountMinlengthModule_basic",
|
||||
"BincountModule_basic",
|
||||
"BincountStaticSizeModule_basic",
|
||||
"BoolFloatConstantModule_basic",
|
||||
"BoolFloatFalseModule_basic",
|
||||
"BoolFloatTrueModule_basic",
|
||||
"BoolIntConstantModule_basic",
|
||||
"BoolIntFalseModule_basic",
|
||||
"BoolIntTrueModule_basic",
|
||||
"CeilFloatModule_basic",
|
||||
"DivFloatModule_basic",
|
||||
"DropoutTrainModule_basic",
|
||||
"ElementwiseAtenLogicalOrOpBrodcastModule_basic",
|
||||
"ElementwiseAtenLogicalOrOpDiffArgs1Module_basic",
|
||||
"ElementwiseAtenLogicalOrOpDiffArgs2Module_basic",
|
||||
"ElementwiseAtenLogicalOrOpDiffArgs3Module_basic",
|
||||
"ElementwiseAtenLogicalOrOpModule_basic",
|
||||
"ElementwiseAtenLogicalOrOpNegativeModule_basic",
|
||||
"ElementwiseAtenLogicalOrOpRandomFloatModule_basic",
|
||||
"ElementwiseAtenLogicalOrOpRandomModule_basic",
|
||||
"ElementwiseClampMaxModule_basic",
|
||||
"ElementwiseClampMinModule_basic",
|
||||
"ElementwiseClampModule_basic",
|
||||
"ElementwiseWhereScalarModule_basic",
|
||||
"ElementwiseWhereScalarOtherModule_basic",
|
||||
"ElementwiseWhereScalarSelfModule_basic",
|
||||
"ElementwiseWhereSelfModule_basic",
|
||||
"EmptyLikeMemoryFormatModule_basic",
|
||||
"EmptyLikeModule_defaultDtype",
|
||||
"EmptyLikeModule_falsePinMemory",
|
||||
"EmptyLikeModule_float",
|
||||
"EmptyLikeModule_int",
|
||||
"EmptyModule_contiguous",
|
||||
"EmptyModule_defaultDtype",
|
||||
"EmptyModule_falsePinMemory",
|
||||
"EmptyModule_float",
|
||||
"EmptyModule_int",
|
||||
"EqIntModule_basic",
|
||||
"Fill_TensorFloat64WithFloat32_basic",
|
||||
"Fill_TensorFloat64WithFloat64_basic",
|
||||
"Fill_TensorFloat64WithInt64_basic",
|
||||
"FullLikeModuleDefaultDtype_basic",
|
||||
"FullLikeModuleFalsePinMemory_basic",
|
||||
"FullLikeModuleFloat2D_basic",
|
||||
"FullLikeModuleFloat3DStatic_basic",
|
||||
"FullLikeModuleFloat3D_basic",
|
||||
"FullLikeModuleInt2DStatic_basic",
|
||||
"FullLikeModuleInt2D_basic",
|
||||
"FullLikeModuleInt3D_basic",
|
||||
"FullModuleDefaultDtype_basic",
|
||||
"FullModuleFalsePinMemory_basic",
|
||||
"FullModuleFloat2D_basic",
|
||||
"FullModuleFloat3D_basic",
|
||||
"FullModuleInt2D_basic",
|
||||
"FullModuleInt3D_basic",
|
||||
"GeFloatIntModule_basic",
|
||||
"GeFloatModule_basic",
|
||||
"GtFloatIntModule_basic",
|
||||
"GtIntModule_basic",
|
||||
"HBC_basic",
|
||||
"HardTanhIntModule_basic",
|
||||
"HardTanhModule_basic",
|
||||
"HardswishModule_basic",
|
||||
"HardswishRandomModule_basic",
|
||||
"IndexPut1DFloatAccumulateModule_basic",
|
||||
"IndexPut1DFloatNonAccumulateModule_basic",
|
||||
"IndexPut1DIntAccumulateModule_basic",
|
||||
"IndexPut1DIntNonAccumulateModule_basic",
|
||||
"IndexPut2DFloatAccumulateModule_basic",
|
||||
"IndexPut2DFloatNonAccumulateModule_basic",
|
||||
"IndexPut2DIntAccumulateModule_basic",
|
||||
"IndexPut2DIntNonAccumulateModule_basic",
|
||||
"IndexPut3DFloatAccumulateModule_basic",
|
||||
"IndexPut3DFloatNonAccumulateModule_basic",
|
||||
"IndexPut3DIntAccumulateModule_basic",
|
||||
"IndexPut3DIntNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin1DFloatAccumulateModule_basic",
|
||||
"IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin1DIntAccumulateModule_basic",
|
||||
"IndexPutHackedTwin1DIntNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin2DFloatAccumulateModule_basic",
|
||||
"IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin2DIntAccumulateModule_basic",
|
||||
"IndexPutHackedTwin2DIntNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin3DFloatAccumulateModule_basic",
|
||||
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin3DIntAccumulateModule_basic",
|
||||
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
|
||||
"IndexPutImpl1DFloatAccumulateModule_basic",
|
||||
"IndexPutImpl1DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImpl1DIntAccumulateModule_basic",
|
||||
"IndexPutImpl1DIntNonAccumulateModule_basic",
|
||||
"IndexPutImpl2DFloatAccumulateModule_basic",
|
||||
"IndexPutImpl2DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImpl3DFloatAccumulateModule_basic",
|
||||
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
||||
"IndexSelectDynamicIndexSizeModule_basic",
|
||||
"IndexSelectDynamicInputSizeModule_basic",
|
||||
"IndexSelectDynamicModulebasic",
|
||||
"IndexSelectSingleIdxModule_basic",
|
||||
"IndexSelectTwoIdxModule_basic",
|
||||
"IndexSelectWholeDimensionModule_basic",
|
||||
"IndexSelectWholeTensorModule_basic",
|
||||
"IndexTensorModule_basic",
|
||||
"MaskedFillScalarDefaultModule_basic",
|
||||
"MaskedFillScalarFloatValueModule_basic",
|
||||
"MaskedFillScalarIntValueModule_basic",
|
||||
"Matmul_dot",
|
||||
"Matmul_matvec",
|
||||
"Matmul_vecmat",
|
||||
"MaxPool2dCeilModeTrueModule_basic",
|
||||
"MaxPool2dModule_basic",
|
||||
"MaxPool2dStaticModule_basic",
|
||||
"MaxPool2dWith3dInputModule_basic",
|
||||
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
||||
"MaxPool2dWithIndicesAllOnesModule_basic",
|
||||
"MaxPool2dWithIndicesBackwardDynamic3DModule_basic",
|
||||
"MaxPool2dWithIndicesBackwardDynamic4DModule_basic",
|
||||
"MaxPool2dWithIndicesBackwardStatic3DModule_basic",
|
||||
"MaxPool2dWithIndicesBackwardStatic4DModule_basic",
|
||||
"MaxPool2dWithIndicesCeilModeTrueModule_basic",
|
||||
"MaxPool2dWithIndicesFullSizeKernelModule_basic",
|
||||
"MaxPool2dWithIndicesModule_basic",
|
||||
"MaxPool2dWithIndicesNonDefaultDilationModule_basic",
|
||||
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
|
||||
"MaxPool2dWithIndicesNonDefaultParamsModule_basic",
|
||||
"MaxPool2dWithIndicesNonDefaultStrideModule_basic",
|
||||
"MaxPool2dWithIndicesStaticModule_basic",
|
||||
"MaxPool2dWithIndicesWith3dInputModule_basic",
|
||||
"MeanDimAllReduceKeepdimModule_basic",
|
||||
"MeanDimAllReduceModule_basic",
|
||||
"MeanDimDtypeModule_basic",
|
||||
"MeanDimKeepdimModule_basic",
|
||||
"MeanDimModule_basic",
|
||||
"MeanDimNegativeModule_basic",
|
||||
"MeanDtypeModule_basic",
|
||||
"MeanDynamicSizesModule_basic",
|
||||
"MeanModule_basic",
|
||||
"MobilenetV3Module_basic",
|
||||
"MulIntModule_basic",
|
||||
"NativeBatchNorm1DModule_basic",
|
||||
"NativeBatchNorm2DModule_basic",
|
||||
"NativeBatchNorm3DModule_basic",
|
||||
"NativeBatchNormNoneWeightModule_basic",
|
||||
"NativeLayerNormDynamicModule_basic",
|
||||
"NativeLayerNormModule_basic",
|
||||
"NeFloatIntModule_basic",
|
||||
"NeIntModule_basic",
|
||||
"NewEmptyModuleDefaultDtype_basic",
|
||||
"NewEmptyModuleFalsePinMemory_basic",
|
||||
"NewEmptyModuleFloat2D_basic",
|
||||
"NewEmptyModuleFloat3D_basic",
|
||||
"NewEmptyModuleInt2D_basic",
|
||||
"NewEmptyModuleInt3D_basic",
|
||||
"NewEmptyModuleLayoutIntDtype_basic",
|
||||
"NewEmptyModuleNonDefaultFloatDtype_basic",
|
||||
"NewEmptyModuleNonDefaultIntDtype_basic",
|
||||
"NewOnesModuleDefaultDtype_basic",
|
||||
"NewOnesModuleFalsePinMemory_basic",
|
||||
"NewOnesModuleFloat2D_basic",
|
||||
"NewOnesModuleFloat3D_basic",
|
||||
"NewOnesModuleInt2D_basic",
|
||||
"NewOnesModuleInt3D_basic",
|
||||
"NewZerosModuleDefaultDtype_basic",
|
||||
"NewZerosModuleFalsePinMemory_basic",
|
||||
"NewZerosModuleFloat2D_basic",
|
||||
"NewZerosModuleFloat3D_basic",
|
||||
"NewZerosModuleInt2D_basic",
|
||||
"NewZerosModuleInt3D_basic",
|
||||
"NllLossModuleBackward1DMeanWeight_basic",
|
||||
"NllLossModuleBackward1DMean_basic",
|
||||
"NllLossModuleBackward1DSumWeight_basic",
|
||||
"NllLossModuleBackward1DSum_basic",
|
||||
"NllLossModuleBackward1DWeight_basic",
|
||||
"NllLossModuleBackward1D_basic",
|
||||
"NllLossModuleBackwardMeanWeight_basic",
|
||||
"NllLossModuleBackwardMean_basic",
|
||||
"NllLossModuleBackwardSumWeight_basic",
|
||||
"NllLossModuleBackwardSum_basic",
|
||||
"NllLossModuleBackwardWeight_basic",
|
||||
"NllLossModuleBackward_basic",
|
||||
"NllLossModuleBackward_ignore_index",
|
||||
"NllLossModule_1D_basic",
|
||||
"NllLossModule_basic",
|
||||
"NllLossModule_ignore_index_out_of_bounds_basic",
|
||||
"NllLossModule_mean_basic",
|
||||
"NllLossModule_sum_basic",
|
||||
"NumelModule_basic",
|
||||
"NumelZeroRankModule_basic",
|
||||
"OnesLikeModule_defaultDtype",
|
||||
"OnesLikeModule_falsePinMemory",
|
||||
"OnesLikeModule_float",
|
||||
"OnesLikeModule_int",
|
||||
"OnesModuleDefaultDtype_basic",
|
||||
"OnesModuleFalsePinMemory_basic",
|
||||
"OnesModuleFloat_basic",
|
||||
"OnesModuleInt_basic",
|
||||
"QuantizedMLP_basic",
|
||||
"RandLikeDtypeModule_basic",
|
||||
"RandLikeModule_basic",
|
||||
"ReduceMaxKeepDimReturnBoth_basic",
|
||||
"ReduceMaxNegativeDim_basic",
|
||||
"ReshapeAliasCollapseModule_basic",
|
||||
"ReshapeAliasExpandModule_basic",
|
||||
"ReturnThreeTensorFloat32_basic",
|
||||
"ReturnTwoTensorF32I64_basic",
|
||||
"ScalarImplicitFloatModule_basic",
|
||||
"ScalarImplicitIntModule_basic",
|
||||
"SelectIntModule_basic",
|
||||
"SliceEndSleStartModule_basic",
|
||||
"SliceNegIdxModule_basic",
|
||||
"SliceOutOfLowerBoundEndIndexModule_basic",
|
||||
"SliceOutOfLowerBoundStartIndexModule_basic",
|
||||
"SliceOutOfUpperBoundIndexModule_basic",
|
||||
"SliceSingleIdxModule_basic",
|
||||
"SliceSizeTwoStepModule_basic",
|
||||
"SliceStartEqEndModule_basic",
|
||||
"SliceWholeTensorModule_basic",
|
||||
"SqrtIntConstantModule_basic",
|
||||
"SqrtIntModule_basic",
|
||||
"StdBiasedModule_basic",
|
||||
"StdUnbiasedModule_basic",
|
||||
"SubFloatModule_basic",
|
||||
"SubIntModule_basic",
|
||||
"TModuleRank0_basic",
|
||||
"TModuleRank1_basic",
|
||||
"TableBatchEmbeddingModule_basic",
|
||||
"TensorToBoolZeroRank_basic",
|
||||
"TensorToBool_basic",
|
||||
"TensorToFloatZeroRank_basic",
|
||||
"TensorToFloat_basic",
|
||||
"TensorToIntZeroRank_basic",
|
||||
"TensorToInt_basic",
|
||||
"TensorsConcatModule_basic",
|
||||
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
|
||||
"TestMultipleTensorReturn_basic",
|
||||
"Threshold1dFloatModule_basic",
|
||||
"Threshold1dIntI32Module_basic",
|
||||
"Threshold1dIntModule_basic",
|
||||
"Threshold2dFloatModule_basic",
|
||||
"Threshold2dIntModule_basic",
|
||||
"Threshold3dFloatModule_basic",
|
||||
"Threshold3dIntModule_basic",
|
||||
"ThresholdBackward1dFloatModule_basic",
|
||||
"ThresholdBackward1dIntModule_basic",
|
||||
"ThresholdBackward1dMixedModule_basic",
|
||||
"ThresholdBackward2dFloatModule_basic",
|
||||
"ThresholdBackward2dIntModule_basic",
|
||||
"ThresholdBackward2dMixedModule_basic",
|
||||
"ThresholdBackward3dFloatModule_basic",
|
||||
"ThresholdBackward3dIntModule_basic",
|
||||
"ThresholdBackward3dMixedModule_basic",
|
||||
"TorchPrimLoopForLikeModule_basic",
|
||||
"TorchPrimLoopWhileLikeModule_basic",
|
||||
"UniformModule_basic",
|
||||
"UniformStaticModule_basic",
|
||||
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"VarBiasedModule_basic",
|
||||
"VarUnbiasedModule_basic",
|
||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"ZeroFloat32Module_basic",
|
||||
"ZeroInt32Module_basic",
|
||||
"ZeroInt64Module_basic",
|
||||
"ZerosLikeModule_defaultDtype",
|
||||
"ZerosLikeModule_falsePinMemory",
|
||||
"ZerosLikeModule_float",
|
||||
"ZerosLikeModule_int",
|
||||
"ZerosModuleDefaultDtype_basic",
|
||||
"ZerosModuleFalsePinMemory_basic",
|
||||
"ZerosModuleFloat2D_basic",
|
||||
"ZerosModuleFloat3D_basic",
|
||||
"ZerosModuleInt2D_basic",
|
||||
"ZerosModuleInt3D_basic",
|
||||
}
|
||||
|
|
|
@ -12,8 +12,8 @@
|
|||
#include <torch/csrc/lazy/backend/lowering_context.h>
|
||||
#include <torch/csrc/lazy/core/shape.h>
|
||||
|
||||
#include <torch_mlir/csrc/base_lazy_backend/generated/LazyNativeFunctions.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/backend_impl.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/generated/LazyNativeFunctions.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
|
||||
#include <torch_mlir/csrc/utils/debug.h>
|
||||
#include <torch_mlir/csrc/utils/exception.h>
|
||||
|
@ -60,10 +60,13 @@ public:
|
|||
|
||||
// Vendor backend specific lowering can be exec here before returning.
|
||||
for (const auto &instance : instances) {
|
||||
std::cout << "Instance received at Compile: \n"
|
||||
<< GetComputationBackendText(instance) << std::endl;
|
||||
// Store computation instance for external access after compilation.
|
||||
GetLatestComputation() = instance;
|
||||
}
|
||||
|
||||
std::cout << "Received " << instances.size()
|
||||
<< " computation instances at Compile!" << std::endl;
|
||||
|
||||
return instances;
|
||||
}
|
||||
|
||||
|
@ -133,9 +136,13 @@ public:
|
|||
* */
|
||||
std::string
|
||||
GetComputationBackendText(const ComputationPtr computation) const override {
|
||||
auto mlir_computation =
|
||||
static_cast<TorchMlirComputation *>(computation.get());
|
||||
return mlir_computation->to_string();
|
||||
// Store computation instance for external access after compilation.
|
||||
// We do this in GetComputationBackendText since there may be instances
|
||||
// where a ComputationPtr does not pass through Compile (e.g. when using
|
||||
// DumpUtil::ToBackend.)
|
||||
GetLatestComputation() = computation;
|
||||
|
||||
return computation->to_string();
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -154,5 +161,11 @@ void InitExampleMlirBackend() {
|
|||
g_registrar.reset(new BackendRegistrar(GetExampleMlirBackendImpl()));
|
||||
}
|
||||
|
||||
ComputationPtr &GetLatestComputation() {
|
||||
// Store the computation from the most recent compile.
|
||||
static ComputationPtr computation;
|
||||
return computation;
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
|
|
|
@ -23,5 +23,7 @@ torch::lazy::BackendImplInterface *GetExampleMlirBackendImpl();
|
|||
|
||||
void InitExampleMlirBackend();
|
||||
|
||||
ComputationPtr &GetLatestComputation();
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
|
|
|
@ -10,6 +10,8 @@
|
|||
#include "torch/csrc/jit/python/pybind.h"
|
||||
#include "torch/csrc/lazy/backend/backend_interface.h"
|
||||
|
||||
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
|
||||
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
@ -61,7 +63,16 @@ void Shutdown() {
|
|||
} // anonymous namespace
|
||||
|
||||
PYBIND11_MODULE(_EXAMPLE_MLIR_BACKEND, m) {
|
||||
py::class_<torch::lazy::TorchMlirComputation>(m, "TorchMlirComputation")
|
||||
.def("to_string", &torch::lazy::TorchMlirComputation::to_string)
|
||||
.def("debug_string", &torch::lazy::TorchMlirComputation::debug_string);
|
||||
|
||||
m.doc() = ("pybind11 for example MLIR LTC backend.");
|
||||
m.def("get_latest_computation", []() {
|
||||
auto computation = static_cast<torch::lazy::TorchMlirComputation *>(
|
||||
torch::lazy::GetLatestComputation().get());
|
||||
return py::cast(computation);
|
||||
});
|
||||
m.def("_initialize", []() {
|
||||
NoGilSection gil;
|
||||
Initialize();
|
||||
|
|
|
@ -14,13 +14,19 @@ Based on LTC code samples by ramiro050
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
|
||||
import torch
|
||||
import torch._C
|
||||
import torch._lazy
|
||||
import torch._lazy.ts_backend
|
||||
from datasets import load_dataset
|
||||
from datasets.dataset_dict import DatasetDict
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import BertForSequenceClassification, \
|
||||
BertConfig, BertTokenizer, AdamW, get_scheduler
|
||||
from typing import List
|
||||
|
||||
|
||||
def tokenize_dataset(dataset: DatasetDict) -> DatasetDict:
|
||||
|
@ -42,8 +48,7 @@ def train(model: BertForSequenceClassification,
|
|||
num_epochs: int,
|
||||
num_training_steps: int,
|
||||
train_dataloader: DataLoader,
|
||||
device: torch.device,
|
||||
do_mark_step: bool) -> List[torch.Tensor]:
|
||||
device: torch.device) -> List[torch.Tensor]:
|
||||
optimizer = AdamW(model.parameters(), lr=5e-5)
|
||||
lr_scheduler = get_scheduler('linear', optimizer=optimizer,
|
||||
num_warmup_steps=0,
|
||||
|
@ -63,31 +68,21 @@ def train(model: BertForSequenceClassification,
|
|||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if do_mark_step and 'lazy' in str(model.device):
|
||||
if 'lazy' in str(model.device):
|
||||
print("Calling Mark Step")
|
||||
torch._lazy.mark_step()
|
||||
|
||||
return losses
|
||||
|
||||
|
||||
def main(device, lower_only, full_size):
|
||||
if device in ("TS", "MLIR_EXAMPLE"):
|
||||
import torch._lazy
|
||||
def main(device='lazy', full_size=False):
|
||||
"""
|
||||
Load model to specified device. Ensure that any backends have been initialized by this point.
|
||||
|
||||
if device == "TS":
|
||||
import torch._lazy.ts_backend
|
||||
|
||||
torch._lazy.ts_backend.init()
|
||||
|
||||
elif device == "MLIR_EXAMPLE":
|
||||
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
|
||||
|
||||
ltc_backend._initialize()
|
||||
|
||||
device = "lazy"
|
||||
print("Initialized backend")
|
||||
else:
|
||||
device = device.lower()
|
||||
:param device: name of device to load tensors to
|
||||
:param full_size: if true, use a full pretrained bert-base-cased model instead of a smaller variant
|
||||
"""
|
||||
torch.manual_seed(0)
|
||||
|
||||
tokenized_datasets = tokenize_dataset(load_dataset('imdb'))
|
||||
small_train_dataset = tokenized_datasets['train'].shuffle(seed=42) \
|
||||
|
@ -117,22 +112,20 @@ def main(device, lower_only, full_size):
|
|||
|
||||
num_epochs = 3
|
||||
num_training_steps = num_epochs * len(train_dataloader)
|
||||
losses = train(model, num_epochs,
|
||||
num_training_steps, train_dataloader, device, not lower_only)
|
||||
losses = train(model, num_epochs, num_training_steps, train_dataloader, device)
|
||||
|
||||
# Get debug information from LTC
|
||||
if 'ltc_backend' in sys.modules:
|
||||
computation = ltc_backend.get_latest_computation()
|
||||
if computation:
|
||||
print(computation.debug_string())
|
||||
|
||||
if lower_only:
|
||||
print('\nJIT Graph:')
|
||||
import torch._C
|
||||
graph_str = torch._C._lazy._get_tensors_backend([losses[0]])
|
||||
print(graph_str)
|
||||
else:
|
||||
# Execute computation
|
||||
print('Loss: ', losses)
|
||||
|
||||
return model, losses
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(0)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
|
@ -142,13 +135,6 @@ if __name__ == "__main__":
|
|||
default="MLIR_EXAMPLE",
|
||||
help="The device type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--lower_only",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Only get backend printout -- do not execute computation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--full_size",
|
||||
|
@ -157,4 +143,17 @@ if __name__ == "__main__":
|
|||
help="Use full sized BERT model instead of one with smaller parameterization",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args.device, args.lower_only, args.full_size)
|
||||
|
||||
if args.device in ("TS", "MLIR_EXAMPLE"):
|
||||
if args.device == "TS":
|
||||
torch._lazy.ts_backend.init()
|
||||
|
||||
elif args.device == "MLIR_EXAMPLE":
|
||||
ltc_backend._initialize()
|
||||
|
||||
device = "lazy"
|
||||
print("Initialized backend")
|
||||
else:
|
||||
device = args.device.lower()
|
||||
|
||||
main(device, args.full_size)
|
||||
|
|
|
@ -6,30 +6,22 @@
|
|||
Example use of the example Torch MLIR LTC backend.
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
|
||||
import torch
|
||||
import torch._lazy
|
||||
import torch._lazy.ts_backend
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def main(device):
|
||||
import torch
|
||||
def main(device='lazy'):
|
||||
"""
|
||||
Load model to specified device. Ensure that any backends have been initialized by this point.
|
||||
|
||||
if device in ("TS", "MLIR_EXAMPLE"):
|
||||
import torch._lazy
|
||||
|
||||
if device == "TS":
|
||||
import torch._lazy.ts_backend
|
||||
|
||||
torch._lazy.ts_backend.init()
|
||||
|
||||
elif device == "MLIR_EXAMPLE":
|
||||
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
|
||||
|
||||
ltc_backend._initialize()
|
||||
|
||||
device = "lazy"
|
||||
print("Initialized backend")
|
||||
else:
|
||||
device = device.lower()
|
||||
:param device: name of device to load tensors to
|
||||
"""
|
||||
torch.manual_seed(0)
|
||||
|
||||
inputs = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float32, device=device)
|
||||
assert inputs.device.type == device
|
||||
|
@ -57,24 +49,35 @@ def main(device):
|
|||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
num_epochs = 3
|
||||
losses = []
|
||||
for _ in range(num_epochs):
|
||||
optimizer.zero_grad()
|
||||
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, targets)
|
||||
loss.backward()
|
||||
losses.append(loss)
|
||||
|
||||
optimizer.step()
|
||||
|
||||
if device == "lazy":
|
||||
print("Calling Mark Step")
|
||||
torch._lazy.mark_step()
|
||||
|
||||
print()
|
||||
print(loss)
|
||||
# Get debug information from LTC
|
||||
if 'ltc_backend' in sys.modules:
|
||||
computation = ltc_backend.get_latest_computation()
|
||||
if computation:
|
||||
print(computation.debug_string())
|
||||
|
||||
print(losses)
|
||||
|
||||
return model, losses
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(0)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
|
@ -85,4 +88,17 @@ if __name__ == "__main__":
|
|||
help="The device type",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args.device)
|
||||
|
||||
if args.device in ("TS", "MLIR_EXAMPLE"):
|
||||
if args.device == "TS":
|
||||
torch._lazy.ts_backend.init()
|
||||
|
||||
elif args.device == "MLIR_EXAMPLE":
|
||||
ltc_backend._initialize()
|
||||
|
||||
device = "lazy"
|
||||
print("Initialized backend")
|
||||
else:
|
||||
device = args.device.lower()
|
||||
|
||||
main(device)
|
||||
|
|
|
@ -323,24 +323,14 @@ std::shared_ptr<torch::jit::Graph> TorchMlirComputation::graph() const {
|
|||
|
||||
MlirOperation TorchMlirComputation::func_op() const { return func_op_; }
|
||||
|
||||
const std::string TorchMlirComputation::to_string() const {
|
||||
// Since we use the C-MLIR API, we need to use a callback to print.
|
||||
MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) {
|
||||
// user_data is a void ptr to some data structure of our choice -- in this
|
||||
// case, the string stream where we'll be accumulating the strings.
|
||||
std::stringstream* ss_ptr = static_cast<std::stringstream*>(user_data);
|
||||
*ss_ptr << std::string(part.data, part.length);
|
||||
};
|
||||
|
||||
const std::string TorchMlirComputation::debug_string() const {
|
||||
std::stringstream ss;
|
||||
|
||||
// JIT Graph
|
||||
ss << "JIT Graph: \n" << graph_->toString() << "\n\n";
|
||||
|
||||
// MLIR
|
||||
ss << "MLIR: \n";
|
||||
mlirOperationPrint(func_op_, print_callback, &ss);
|
||||
ss << "\n";
|
||||
ss << "MLIR: \n" << to_string() << "\n";
|
||||
|
||||
// Input/Output Mapping
|
||||
ss << "Input/Output Alias Mapping: \n";
|
||||
|
@ -356,5 +346,18 @@ const std::string TorchMlirComputation::to_string() const {
|
|||
return ss.str();
|
||||
}
|
||||
|
||||
const std::string TorchMlirComputation::to_string() const {
|
||||
// Since we use the C-MLIR API, we need to use a callback to print.
|
||||
MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) {
|
||||
// user_data is a void ptr to some data structure of our choice -- in this
|
||||
// case, the string stream where we'll be accumulating the strings.
|
||||
std::stringstream* ss_ptr = static_cast<std::stringstream*>(user_data);
|
||||
*ss_ptr << std::string(part.data, part.length);
|
||||
};
|
||||
std::stringstream ss;
|
||||
mlirOperationPrint(func_op_, print_callback, &ss);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
|
|
|
@ -135,6 +135,8 @@ public:
|
|||
|
||||
MlirOperation func_op() const;
|
||||
|
||||
const std::string debug_string() const;
|
||||
|
||||
const std::string to_string() const;
|
||||
|
||||
private:
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from .lazy_tensor_core import LazyTensorCoreTestConfig
|
||||
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
|
||||
from .native_torch import NativeTorchTestConfig
|
||||
from .torchscript import TorchScriptTestConfig
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
|
||||
import torch
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
|
||||
|
||||
|
||||
class LazyTensorCoreTestConfig(TestConfig):
|
||||
"""TestConfig that runs torch.nn.Module thru the Lazy Tensor Core frontend for Torch MLIR"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
ltc_backend._initialize()
|
||||
|
||||
def compile(self, program: torch.nn.Module) -> torch.nn.Module:
|
||||
return program.to('lazy')
|
||||
|
||||
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
||||
result: Trace = []
|
||||
|
||||
for item in trace:
|
||||
# We need to move all the inputs to the lazy device before running in LTC.
|
||||
lazy_inputs = [arg.to('lazy') for arg in item.inputs]
|
||||
output = getattr(artifact, item.symbol)(*lazy_inputs)
|
||||
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
output=output.to('cpu')))
|
||||
|
||||
return result
|
Loading…
Reference in New Issue