mirror of https://github.com/llvm/torch-mlir
parent
8cad02f87e
commit
2374098d71
|
@ -167,6 +167,12 @@ jobs:
|
|||
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
|
||||
python -m e2e_testing.torchscript.main --config=eager_mode -v
|
||||
|
||||
- name: Run mhlo e2e integration tests
|
||||
if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
|
||||
run: |
|
||||
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
|
||||
python -m e2e_testing.torchscript.main --config=mhlo -v
|
||||
|
||||
- name: Run tosa e2e integration tests
|
||||
if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
|
||||
run: |
|
||||
|
|
|
@ -15,20 +15,27 @@ from torch_mlir_e2e_test.torchscript.serialization import deserialize_all_tests_
|
|||
|
||||
# Available test configs.
|
||||
from torch_mlir_e2e_test.torchscript.configs import (
|
||||
LazyTensorCoreTestConfig, LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig
|
||||
LazyTensorCoreTestConfig,
|
||||
LinalgOnTensorsBackendTestConfig,
|
||||
MhloBackendTestConfig,
|
||||
NativeTorchTestConfig,
|
||||
TorchScriptTestConfig,
|
||||
TosaBackendTestConfig,
|
||||
EagerModeTestConfig
|
||||
)
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
|
||||
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
||||
|
||||
from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET
|
||||
from .xfail_sets import REFBACKEND_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET
|
||||
|
||||
# Import tests to register them in the global registry.
|
||||
from torch_mlir_e2e_test.test_suite import register_all_tests
|
||||
register_all_tests()
|
||||
|
||||
def _get_argparse():
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode', 'lazy_tensor_core']
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core']
|
||||
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
||||
parser.add_argument('-c', '--config',
|
||||
choices=config_choices,
|
||||
|
@ -36,6 +43,7 @@ def _get_argparse():
|
|||
help=f'''
|
||||
Meaning of options:
|
||||
"refbackend": run through torch-mlir's RefBackend.
|
||||
"mhlo": run through torch-mlir's default MHLO backend.
|
||||
"tosa": run through torch-mlir's default TOSA backend.
|
||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||
|
@ -78,6 +86,9 @@ def main():
|
|||
if args.config == 'tosa':
|
||||
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
|
||||
xfail_set = all_test_unique_names - TOSA_PASS_SET
|
||||
if args.config == 'mhlo':
|
||||
config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
|
||||
xfail_set = all_test_unique_names - MHLO_PASS_SET
|
||||
elif args.config == 'native_torch':
|
||||
config = NativeTorchTestConfig()
|
||||
xfail_set = {}
|
||||
|
|
|
@ -21,6 +21,139 @@ EAGER_MODE_XFAIL_SET = {
|
|||
"Matmul_vecmat"
|
||||
}
|
||||
|
||||
MHLO_PASS_SET = {
|
||||
"FlattenStaticModule_basic",
|
||||
"FlattenRank0Module_basic",
|
||||
"TensorsConcatNegativeDimModule_basic",
|
||||
"NumelModule_basic",
|
||||
"ReduceSumDimIntListEmptyDimModule_basic",
|
||||
"SqueezeModule_allUnitDim",
|
||||
"SqueezeDimModule_unitDim",
|
||||
"MeanModule_basic",
|
||||
"MeanDynamicSizesModule_basic",
|
||||
"MeanDimEmptyDimModule_basic",
|
||||
"NumToTensorFloatModule_basic",
|
||||
"AtenToDeviceModule_basic",
|
||||
"AvgPool2dStaticModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_basic",
|
||||
"Convolution2DStaticModule_basic",
|
||||
"ElementwiseCloneContiguousModule_basic",
|
||||
"ElementwiseCloneModule_basic",
|
||||
"ElementwiseBinaryStaticShapeModule_basic",
|
||||
"ReturnThreeTensorFloat32_basic",
|
||||
"BoolTensorReturnFalseModule_basic",
|
||||
"BoolTensorReturnTrueModule_basic",
|
||||
"BoolTensorReturnMixedModule_basic",
|
||||
"SqueezeModule_static",
|
||||
"TModuleRank1_basic",
|
||||
"TModuleRank0_basic",
|
||||
"ElementwiseToDtypeIdentityModule_basic",
|
||||
"View1DFoldModule_basic",
|
||||
"UnsafeView1DFoldModule_basic",
|
||||
"SqueezeDimModule_static",
|
||||
"SqueezeDimModule_identity",
|
||||
"SliceModule_basic",
|
||||
"SliceNegIdxModule_basic",
|
||||
"SliceOutOfLowerBoundStartIndexModule_basic",
|
||||
"SliceOutOfUpperBoundIndexModule_basic",
|
||||
"SliceStartEqEndModule_basic",
|
||||
"SliceSizeTwoStepModule_basic",
|
||||
"SliceWholeTensorModule_basic",
|
||||
"ReturnTwoTensorF32I64_basic",
|
||||
"Matmul4dStatic_basic",
|
||||
"Matmul_dot",
|
||||
"Matmul_2d",
|
||||
"Matmul_matvec",
|
||||
"Matmul_vecmat",
|
||||
"MaxPool2dWithIndicesStaticModule_basic",
|
||||
"MmDagModule_basic",
|
||||
"MmModule_basic",
|
||||
"MmModule_chained",
|
||||
"MaxPool2dStaticModule_basic",
|
||||
"PermuteModule_basic",
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
"ZerosModuleDefaultDtype_basic",
|
||||
"ZerosModuleInt2D_basic",
|
||||
"ZerosModuleInt3D_basic",
|
||||
"ZerosModuleFloat2D_basic",
|
||||
"ZerosModuleFloat3D_basic",
|
||||
"ZerosModuleFalsePinMemory_basic",
|
||||
"OnesModuleDefaultDtype_basic",
|
||||
"OnesModuleInt_basic",
|
||||
"OnesModuleFloat_basic",
|
||||
"OnesModuleFalsePinMemory_basic",
|
||||
"NewZerosModuleDefaultDtype_basic",
|
||||
"NewZerosModuleInt2D_basic",
|
||||
"NewZerosModuleInt3D_basic",
|
||||
"NewZerosModuleFloat2D_basic",
|
||||
"NewZerosModuleFloat3D_basic",
|
||||
"NewZerosModuleFalsePinMemory_basic",
|
||||
"NewOnesModuleDefaultDtype_basic",
|
||||
"NewOnesModuleInt2D_basic",
|
||||
"NewOnesModuleInt3D_basic",
|
||||
"NewOnesModuleFloat2D_basic",
|
||||
"NewOnesModuleFloat3D_basic",
|
||||
"NewOnesModuleFalsePinMemory_basic",
|
||||
"DropoutEvalIntModule_basic",
|
||||
"DropoutEvalFloatModule_basic",
|
||||
"ContiguousModule_basic",
|
||||
"DropoutModule_basic",
|
||||
"ViewCollapseModule_basic",
|
||||
"ViewCollapseInferredDimModule_basic",
|
||||
"ViewDynamicExpandCollapseModule_basic",
|
||||
"ViewDynamicExpandModule_basic",
|
||||
"ViewExpandModule_basic",
|
||||
"ViewExpandOnesModule_basic",
|
||||
"ViewExpandOnesBeforeAndAfterModule_basic",
|
||||
"ViewExpandOnesMiddleModule_basic",
|
||||
"ViewExpandCollapseModule_basic",
|
||||
"ViewExpandCollapseWithOnesModule_basic",
|
||||
"ViewExpandInferredDimModule_basic",
|
||||
"ViewNoChangeStaticModule_basic",
|
||||
"ViewNoChange1dModule_basic",
|
||||
"ViewNoChange2dModule_basic",
|
||||
"ViewNoChange3dModule_basic",
|
||||
"UnsafeViewExpandModule_basic",
|
||||
"ReduceMaxAllDims_basic",
|
||||
"ReduceMaxFloatModule_basic",
|
||||
"ReduceMaxSignedIntModule_basic",
|
||||
"ReduceMaxUnsignedIntModule_basic",
|
||||
"ReduceSumDimIntListFloatModule_basic",
|
||||
"ReduceSumDimIntListIntModule_basic",
|
||||
"ReduceSumFloatModule_basic",
|
||||
"ReduceSumSignedIntModule_basic",
|
||||
"ReduceSumUnsignedIntModule_basic",
|
||||
"RepeatModule_basic",
|
||||
"ReshapeAliasCollapseModule_basic",
|
||||
"ReshapeAliasExpandModule_basic",
|
||||
"ReshapeExpandModule_basic",
|
||||
"TestMultipleTensorReturn_basic",
|
||||
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
||||
"BaddbmmStaticModule_basic",
|
||||
"BaddbmmBroadcast1DInputModule_basic",
|
||||
"BaddbmmBroadcast2DInputModule_basic",
|
||||
"NarrowHorizontalTest2_basic",
|
||||
"NarrowHorizontalTest_basic",
|
||||
"NarrowVerticalTest2_basic",
|
||||
"NarrowVerticalTest_basic",
|
||||
"NumToTensorIntModule_basic",
|
||||
"NumpyTRank0Module_basic",
|
||||
"NumpyTRank1Module_basic",
|
||||
"NumpyTRank2Module_basic",
|
||||
"NumpyTRankNStaticModule_basic",
|
||||
"NumpyTRankNDynamicModule_basic",
|
||||
"TModuleRank2_basic",
|
||||
"TensorLiteralModule_basic",
|
||||
"TensorsConcatModule_basic",
|
||||
"TensorOpaqueLiteralModule_basic",
|
||||
"TransposeIntModule_basic",
|
||||
"TransposeIntNegDimsModule_basic",
|
||||
"OnesModuleCPUDevice_basic",
|
||||
"Permute0RankModule_basic",
|
||||
"UnsafeViewCollapseModule_basic",
|
||||
"UnsafeViewDynamicExpandModule_basic",
|
||||
}
|
||||
|
||||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
|
|
|
@ -9,6 +9,9 @@
|
|||
|
||||
#include "torch-mlir/Conversion/Passes.h"
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#endif // TORCH_MLIR_ENABLE_MHLO
|
||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
||||
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
|
||||
|
@ -25,4 +28,11 @@ namespace {
|
|||
#include "torch-mlir/Conversion/Passes.h.inc"
|
||||
} // end namespace
|
||||
|
||||
void mlir::torch::registerConversionPasses() { ::registerPasses(); }
|
||||
void mlir::torch::registerConversionPasses() {
|
||||
::registerPasses();
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
|
||||
return mlir::mhlo::createLegalizeHloToLinalgPass();
|
||||
});
|
||||
#endif // TORCH_MLIR_ENABLE_MHLO
|
||||
}
|
||||
|
|
|
@ -977,8 +977,41 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
|||
v = mhlo::promoteType(rewriter, v, outType);
|
||||
}
|
||||
|
||||
size_t posDim = toPositiveDim(dim, outType.getRank());
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConcatenateOp>(
|
||||
op, ValueRange(builtinTensors), static_cast<uint64_t>(dim));
|
||||
op, ValueRange(builtinTensors), posDim);
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// AtenNumelOp
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
|
||||
AtenNumelOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
auto self = adaptor.self();
|
||||
auto selfTy = self.getType().dyn_cast<RankedTensorType>();
|
||||
size_t rank = selfTy.getRank();
|
||||
|
||||
Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
|
||||
auto loc = op->getLoc();
|
||||
Value numel =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(intType, 1));
|
||||
for (size_t d = 0 ; d < rank; ++ d) {
|
||||
Value dimSize = rewriter.create<arith::IndexCastOp>(
|
||||
loc, intType, rewriter.create<tensor::DimOp>(loc, self, d));
|
||||
numel = rewriter.create<arith::MulIOp>(loc, numel, dimSize);
|
||||
}
|
||||
|
||||
auto outTy = getTypeConverter()->convertType(op.getType());
|
||||
if (outTy != numel.getType()) {
|
||||
rewriter.replaceOpWithNewOp<arith::ExtSIOp>(
|
||||
op, outTy, numel);
|
||||
} else {
|
||||
rewriter.replaceOp(op, numel);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -1067,5 +1100,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
|
||||
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
|
||||
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
|
||||
INSERT_ATENOP_PATTERN(AtenNumelOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
}
|
||||
|
|
|
@ -14,6 +14,8 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
|
|||
DEPENDS
|
||||
MhloDialect
|
||||
ChloDialect
|
||||
MhloToLinalg
|
||||
MLIRMhloPassIncGen
|
||||
TorchMLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
|
@ -24,6 +26,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
|
|||
MLIRPass
|
||||
MhloDialect
|
||||
ChloDialect
|
||||
MhloToLinalg
|
||||
TorchMLIRTorchDialect
|
||||
)
|
||||
|
||||
|
|
|
@ -490,6 +490,9 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
if (!matchPattern(op.dim(), m_TorchConstantIntList(inputDims))) {
|
||||
return rewriter.notifyMatchFailure(op, "non-int dim list unsupported");
|
||||
}
|
||||
if (inputDims.size() == 0) {
|
||||
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
|
||||
}
|
||||
|
||||
for (auto d : inputDims) {
|
||||
d = toPositiveDim(d, inputTy.getRank());
|
||||
|
|
|
@ -256,6 +256,14 @@ public:
|
|||
numel = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
|
||||
numel);
|
||||
|
||||
if (dimSizes.size() == 0) {
|
||||
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
adaptor.self());
|
||||
return success();
|
||||
}
|
||||
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
||||
Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
|
||||
loc, mhloShape.getType(), numel, mhloShape);
|
||||
|
@ -310,6 +318,11 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
|
|||
if (dSize != 1)
|
||||
dims.push_back(r);
|
||||
}
|
||||
if (dims.size() == 0) {
|
||||
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
return success();
|
||||
}
|
||||
|
||||
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
|
||||
if (failed(newDimSizesInfo))
|
||||
|
@ -354,6 +367,11 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
|
|||
SmallVector<int64_t, 4> dims(rank);
|
||||
std::iota(dims.begin(), dims.end(), 0);
|
||||
dims.erase(dims.begin() + dim);
|
||||
if (dims.size() == 0) {
|
||||
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
return success();
|
||||
}
|
||||
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
|
||||
if (failed(newDimSizesInfo))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
|
|
@ -202,7 +202,6 @@ def compile(model: torch.nn.Module,
|
|||
scripted = torch.jit.trace(model, tuple(example_args_for_trace))
|
||||
else:
|
||||
scripted = torch.jit.script(model)
|
||||
|
||||
# Convert all concrete inputs to TensorPlaceholder's, for consistency.
|
||||
arg_placeholders = []
|
||||
for arg in example_args:
|
||||
|
@ -240,7 +239,6 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
|||
""") from None
|
||||
finally:
|
||||
sys.stderr = original_stderr
|
||||
|
||||
if output_type == OutputType.RAW:
|
||||
return mb.module
|
||||
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import abc
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir.ir import Module
|
||||
|
||||
# A type shared between the result of `MhloBackend.compile` and the
|
||||
# input to `MhloBackend.load`. Each backend will likely have a
|
||||
# different definition of this type.
|
||||
CompiledArtifact = TypeVar('CompiledArtifact')
|
||||
|
||||
# A wrapper around a backend-specific loaded program representation
|
||||
# that uniformly translates the `x.method(...)` interface expected of
|
||||
# Torch modules into appropriate lower-level operations.
|
||||
Invoker = TypeVar('Invoker')
|
||||
|
||||
|
||||
class MhloBackend(abc.ABC):
|
||||
"""The interface to an MHLO backend.
|
||||
|
||||
Backends are recommended to raise meaningful exceptions in case of error,
|
||||
ideally with easy reproduction instructions.
|
||||
"""
|
||||
@abc.abstractmethod
|
||||
def compile(self, module: Module) -> CompiledArtifact:
|
||||
"""Compile the provided MLIR module into a compiled artifact.
|
||||
|
||||
The module adheres to the MHLO backend contract
|
||||
(see the VerifyMhloBackendContract pass).
|
||||
|
||||
The compiled artifact can be any type, but must be correctly
|
||||
interpreted by the `load` method.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def load(self, artifact: CompiledArtifact) -> Invoker:
|
||||
"""Load the compiled artifact into a uniformly invokable form.
|
||||
|
||||
The compiled artifact is the result of a previous call to `compile`.
|
||||
|
||||
See the description of `Invoker` for the requirements on the returned
|
||||
type.
|
||||
"""
|
|
@ -0,0 +1,45 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from torch_mlir.ir import *
|
||||
from torch_mlir.passmanager import *
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||
|
||||
from .abc import MhloBackend
|
||||
|
||||
__all__ = [
|
||||
"LinalgOnTensorsMhloBackend",
|
||||
]
|
||||
|
||||
class LinalgOnTensorsMhloBackend(MhloBackend):
|
||||
"""Main entry-point for the linalg-on-tensors based MHLO backend.
|
||||
|
||||
This currently uses the linalg-on-tensors RefBackend for actual execution.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.refbackend = RefBackendLinalgOnTensorsBackend()
|
||||
|
||||
def compile(self, imported_module: Module):
|
||||
"""Compiles an imported module that satisfied the MHLO backend contract.
|
||||
|
||||
Args:
|
||||
imported_module: The MLIR module consisting of funcs in the MHLO
|
||||
dialect.
|
||||
Returns:
|
||||
An opaque, backend specific compiled artifact object that can be
|
||||
passed to `load`.
|
||||
"""
|
||||
run_pipeline_with_repro_report(
|
||||
imported_module,
|
||||
"func.func(hlo-legalize-to-linalg)",
|
||||
"Lowering MLIR-HLO to Linalg-on-Tensors")
|
||||
return self.refbackend.compile(imported_module)
|
||||
|
||||
def load(self, module):
|
||||
"""Loads a compiled artifact into the runtime."""
|
||||
return self.refbackend.load(module)
|
|
@ -7,5 +7,6 @@ from .lazy_tensor_core import LazyTensorCoreTestConfig
|
|||
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
|
||||
from .native_torch import NativeTorchTestConfig
|
||||
from .torchscript import TorchScriptTestConfig
|
||||
from .mhlo_backend import MhloBackendTestConfig
|
||||
from .tosa_backend import TosaBackendTestConfig
|
||||
from .eager_mode import EagerModeTestConfig
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
from torch_mlir_e2e_test.mhlo_backends.abc import MhloBackend
|
||||
from torch_mlir_e2e_test.torchscript.framework import (
|
||||
TestConfig,
|
||||
Trace,
|
||||
TraceItem
|
||||
)
|
||||
from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders
|
||||
from .utils import (
|
||||
recursively_convert_to_numpy,
|
||||
recursively_convert_from_numpy,
|
||||
)
|
||||
|
||||
|
||||
class MhloBackendTestConfig(TestConfig):
|
||||
"""Base class for TestConfig's that are implemented with linalg-on-tensors.
|
||||
|
||||
This class handles all the common lowering that torch-mlir does before
|
||||
reaching the linalg-on-tensors abstraction level.
|
||||
"""
|
||||
def __init__(self, backend: MhloBackend):
|
||||
super().__init__()
|
||||
self.backend = backend
|
||||
|
||||
def compile(self, program: torch.nn.Module) -> Any:
|
||||
example_args = convert_annotations_to_placeholders(program.forward)
|
||||
module = torch_mlir.compile(
|
||||
program, example_args, output_type="mhlo")
|
||||
|
||||
return self.backend.compile(module)
|
||||
|
||||
def run(self, artifact: Any, trace: Trace) -> Trace:
|
||||
backend_module = self.backend.load(artifact)
|
||||
result: Trace = []
|
||||
for item in trace:
|
||||
numpy_inputs = recursively_convert_to_numpy(item.inputs)
|
||||
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
||||
output = recursively_convert_from_numpy(outputs)
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
output=output))
|
||||
return result
|
Loading…
Reference in New Issue