mirror of https://github.com/llvm/torch-mlir
parent
8cad02f87e
commit
2374098d71
|
@ -167,6 +167,12 @@ jobs:
|
||||||
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
|
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
|
||||||
python -m e2e_testing.torchscript.main --config=eager_mode -v
|
python -m e2e_testing.torchscript.main --config=eager_mode -v
|
||||||
|
|
||||||
|
- name: Run mhlo e2e integration tests
|
||||||
|
if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
|
||||||
|
run: |
|
||||||
|
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
|
||||||
|
python -m e2e_testing.torchscript.main --config=mhlo -v
|
||||||
|
|
||||||
- name: Run tosa e2e integration tests
|
- name: Run tosa e2e integration tests
|
||||||
if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
|
if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
|
||||||
run: |
|
run: |
|
||||||
|
|
|
@ -15,20 +15,27 @@ from torch_mlir_e2e_test.torchscript.serialization import deserialize_all_tests_
|
||||||
|
|
||||||
# Available test configs.
|
# Available test configs.
|
||||||
from torch_mlir_e2e_test.torchscript.configs import (
|
from torch_mlir_e2e_test.torchscript.configs import (
|
||||||
LazyTensorCoreTestConfig, LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig
|
LazyTensorCoreTestConfig,
|
||||||
|
LinalgOnTensorsBackendTestConfig,
|
||||||
|
MhloBackendTestConfig,
|
||||||
|
NativeTorchTestConfig,
|
||||||
|
TorchScriptTestConfig,
|
||||||
|
TosaBackendTestConfig,
|
||||||
|
EagerModeTestConfig
|
||||||
)
|
)
|
||||||
|
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||||
|
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
|
||||||
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
||||||
|
|
||||||
from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET
|
from .xfail_sets import REFBACKEND_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET
|
||||||
|
|
||||||
# Import tests to register them in the global registry.
|
# Import tests to register them in the global registry.
|
||||||
from torch_mlir_e2e_test.test_suite import register_all_tests
|
from torch_mlir_e2e_test.test_suite import register_all_tests
|
||||||
register_all_tests()
|
register_all_tests()
|
||||||
|
|
||||||
def _get_argparse():
|
def _get_argparse():
|
||||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode', 'lazy_tensor_core']
|
config_choices = ['native_torch', 'torchscript', 'refbackend', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core']
|
||||||
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
||||||
parser.add_argument('-c', '--config',
|
parser.add_argument('-c', '--config',
|
||||||
choices=config_choices,
|
choices=config_choices,
|
||||||
|
@ -36,6 +43,7 @@ def _get_argparse():
|
||||||
help=f'''
|
help=f'''
|
||||||
Meaning of options:
|
Meaning of options:
|
||||||
"refbackend": run through torch-mlir's RefBackend.
|
"refbackend": run through torch-mlir's RefBackend.
|
||||||
|
"mhlo": run through torch-mlir's default MHLO backend.
|
||||||
"tosa": run through torch-mlir's default TOSA backend.
|
"tosa": run through torch-mlir's default TOSA backend.
|
||||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
||||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||||
|
@ -78,6 +86,9 @@ def main():
|
||||||
if args.config == 'tosa':
|
if args.config == 'tosa':
|
||||||
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
|
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
|
||||||
xfail_set = all_test_unique_names - TOSA_PASS_SET
|
xfail_set = all_test_unique_names - TOSA_PASS_SET
|
||||||
|
if args.config == 'mhlo':
|
||||||
|
config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
|
||||||
|
xfail_set = all_test_unique_names - MHLO_PASS_SET
|
||||||
elif args.config == 'native_torch':
|
elif args.config == 'native_torch':
|
||||||
config = NativeTorchTestConfig()
|
config = NativeTorchTestConfig()
|
||||||
xfail_set = {}
|
xfail_set = {}
|
||||||
|
|
|
@ -21,6 +21,139 @@ EAGER_MODE_XFAIL_SET = {
|
||||||
"Matmul_vecmat"
|
"Matmul_vecmat"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MHLO_PASS_SET = {
|
||||||
|
"FlattenStaticModule_basic",
|
||||||
|
"FlattenRank0Module_basic",
|
||||||
|
"TensorsConcatNegativeDimModule_basic",
|
||||||
|
"NumelModule_basic",
|
||||||
|
"ReduceSumDimIntListEmptyDimModule_basic",
|
||||||
|
"SqueezeModule_allUnitDim",
|
||||||
|
"SqueezeDimModule_unitDim",
|
||||||
|
"MeanModule_basic",
|
||||||
|
"MeanDynamicSizesModule_basic",
|
||||||
|
"MeanDimEmptyDimModule_basic",
|
||||||
|
"NumToTensorFloatModule_basic",
|
||||||
|
"AtenToDeviceModule_basic",
|
||||||
|
"AvgPool2dStaticModule_basic",
|
||||||
|
"Conv2dWithPaddingDilationStrideStaticModule_basic",
|
||||||
|
"Convolution2DStaticModule_basic",
|
||||||
|
"ElementwiseCloneContiguousModule_basic",
|
||||||
|
"ElementwiseCloneModule_basic",
|
||||||
|
"ElementwiseBinaryStaticShapeModule_basic",
|
||||||
|
"ReturnThreeTensorFloat32_basic",
|
||||||
|
"BoolTensorReturnFalseModule_basic",
|
||||||
|
"BoolTensorReturnTrueModule_basic",
|
||||||
|
"BoolTensorReturnMixedModule_basic",
|
||||||
|
"SqueezeModule_static",
|
||||||
|
"TModuleRank1_basic",
|
||||||
|
"TModuleRank0_basic",
|
||||||
|
"ElementwiseToDtypeIdentityModule_basic",
|
||||||
|
"View1DFoldModule_basic",
|
||||||
|
"UnsafeView1DFoldModule_basic",
|
||||||
|
"SqueezeDimModule_static",
|
||||||
|
"SqueezeDimModule_identity",
|
||||||
|
"SliceModule_basic",
|
||||||
|
"SliceNegIdxModule_basic",
|
||||||
|
"SliceOutOfLowerBoundStartIndexModule_basic",
|
||||||
|
"SliceOutOfUpperBoundIndexModule_basic",
|
||||||
|
"SliceStartEqEndModule_basic",
|
||||||
|
"SliceSizeTwoStepModule_basic",
|
||||||
|
"SliceWholeTensorModule_basic",
|
||||||
|
"ReturnTwoTensorF32I64_basic",
|
||||||
|
"Matmul4dStatic_basic",
|
||||||
|
"Matmul_dot",
|
||||||
|
"Matmul_2d",
|
||||||
|
"Matmul_matvec",
|
||||||
|
"Matmul_vecmat",
|
||||||
|
"MaxPool2dWithIndicesStaticModule_basic",
|
||||||
|
"MmDagModule_basic",
|
||||||
|
"MmModule_basic",
|
||||||
|
"MmModule_chained",
|
||||||
|
"MaxPool2dStaticModule_basic",
|
||||||
|
"PermuteModule_basic",
|
||||||
|
"PermuteNegativeIndexModule_basic",
|
||||||
|
"ZerosModuleDefaultDtype_basic",
|
||||||
|
"ZerosModuleInt2D_basic",
|
||||||
|
"ZerosModuleInt3D_basic",
|
||||||
|
"ZerosModuleFloat2D_basic",
|
||||||
|
"ZerosModuleFloat3D_basic",
|
||||||
|
"ZerosModuleFalsePinMemory_basic",
|
||||||
|
"OnesModuleDefaultDtype_basic",
|
||||||
|
"OnesModuleInt_basic",
|
||||||
|
"OnesModuleFloat_basic",
|
||||||
|
"OnesModuleFalsePinMemory_basic",
|
||||||
|
"NewZerosModuleDefaultDtype_basic",
|
||||||
|
"NewZerosModuleInt2D_basic",
|
||||||
|
"NewZerosModuleInt3D_basic",
|
||||||
|
"NewZerosModuleFloat2D_basic",
|
||||||
|
"NewZerosModuleFloat3D_basic",
|
||||||
|
"NewZerosModuleFalsePinMemory_basic",
|
||||||
|
"NewOnesModuleDefaultDtype_basic",
|
||||||
|
"NewOnesModuleInt2D_basic",
|
||||||
|
"NewOnesModuleInt3D_basic",
|
||||||
|
"NewOnesModuleFloat2D_basic",
|
||||||
|
"NewOnesModuleFloat3D_basic",
|
||||||
|
"NewOnesModuleFalsePinMemory_basic",
|
||||||
|
"DropoutEvalIntModule_basic",
|
||||||
|
"DropoutEvalFloatModule_basic",
|
||||||
|
"ContiguousModule_basic",
|
||||||
|
"DropoutModule_basic",
|
||||||
|
"ViewCollapseModule_basic",
|
||||||
|
"ViewCollapseInferredDimModule_basic",
|
||||||
|
"ViewDynamicExpandCollapseModule_basic",
|
||||||
|
"ViewDynamicExpandModule_basic",
|
||||||
|
"ViewExpandModule_basic",
|
||||||
|
"ViewExpandOnesModule_basic",
|
||||||
|
"ViewExpandOnesBeforeAndAfterModule_basic",
|
||||||
|
"ViewExpandOnesMiddleModule_basic",
|
||||||
|
"ViewExpandCollapseModule_basic",
|
||||||
|
"ViewExpandCollapseWithOnesModule_basic",
|
||||||
|
"ViewExpandInferredDimModule_basic",
|
||||||
|
"ViewNoChangeStaticModule_basic",
|
||||||
|
"ViewNoChange1dModule_basic",
|
||||||
|
"ViewNoChange2dModule_basic",
|
||||||
|
"ViewNoChange3dModule_basic",
|
||||||
|
"UnsafeViewExpandModule_basic",
|
||||||
|
"ReduceMaxAllDims_basic",
|
||||||
|
"ReduceMaxFloatModule_basic",
|
||||||
|
"ReduceMaxSignedIntModule_basic",
|
||||||
|
"ReduceMaxUnsignedIntModule_basic",
|
||||||
|
"ReduceSumDimIntListFloatModule_basic",
|
||||||
|
"ReduceSumDimIntListIntModule_basic",
|
||||||
|
"ReduceSumFloatModule_basic",
|
||||||
|
"ReduceSumSignedIntModule_basic",
|
||||||
|
"ReduceSumUnsignedIntModule_basic",
|
||||||
|
"RepeatModule_basic",
|
||||||
|
"ReshapeAliasCollapseModule_basic",
|
||||||
|
"ReshapeAliasExpandModule_basic",
|
||||||
|
"ReshapeExpandModule_basic",
|
||||||
|
"TestMultipleTensorReturn_basic",
|
||||||
|
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
||||||
|
"BaddbmmStaticModule_basic",
|
||||||
|
"BaddbmmBroadcast1DInputModule_basic",
|
||||||
|
"BaddbmmBroadcast2DInputModule_basic",
|
||||||
|
"NarrowHorizontalTest2_basic",
|
||||||
|
"NarrowHorizontalTest_basic",
|
||||||
|
"NarrowVerticalTest2_basic",
|
||||||
|
"NarrowVerticalTest_basic",
|
||||||
|
"NumToTensorIntModule_basic",
|
||||||
|
"NumpyTRank0Module_basic",
|
||||||
|
"NumpyTRank1Module_basic",
|
||||||
|
"NumpyTRank2Module_basic",
|
||||||
|
"NumpyTRankNStaticModule_basic",
|
||||||
|
"NumpyTRankNDynamicModule_basic",
|
||||||
|
"TModuleRank2_basic",
|
||||||
|
"TensorLiteralModule_basic",
|
||||||
|
"TensorsConcatModule_basic",
|
||||||
|
"TensorOpaqueLiteralModule_basic",
|
||||||
|
"TransposeIntModule_basic",
|
||||||
|
"TransposeIntNegDimsModule_basic",
|
||||||
|
"OnesModuleCPUDevice_basic",
|
||||||
|
"Permute0RankModule_basic",
|
||||||
|
"UnsafeViewCollapseModule_basic",
|
||||||
|
"UnsafeViewDynamicExpandModule_basic",
|
||||||
|
}
|
||||||
|
|
||||||
# Write the TOSA set as a "passing" set as it is very early in development
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
# and very few tests work yet.
|
# and very few tests work yet.
|
||||||
TOSA_PASS_SET = {
|
TOSA_PASS_SET = {
|
||||||
|
|
|
@ -9,6 +9,9 @@
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/Passes.h"
|
#include "torch-mlir/Conversion/Passes.h"
|
||||||
|
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
#endif // TORCH_MLIR_ENABLE_MHLO
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
||||||
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
|
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
|
||||||
|
@ -25,4 +28,11 @@ namespace {
|
||||||
#include "torch-mlir/Conversion/Passes.h.inc"
|
#include "torch-mlir/Conversion/Passes.h.inc"
|
||||||
} // end namespace
|
} // end namespace
|
||||||
|
|
||||||
void mlir::torch::registerConversionPasses() { ::registerPasses(); }
|
void mlir::torch::registerConversionPasses() {
|
||||||
|
::registerPasses();
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||||
|
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
|
||||||
|
return mlir::mhlo::createLegalizeHloToLinalgPass();
|
||||||
|
});
|
||||||
|
#endif // TORCH_MLIR_ENABLE_MHLO
|
||||||
|
}
|
||||||
|
|
|
@ -977,8 +977,41 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
||||||
v = mhlo::promoteType(rewriter, v, outType);
|
v = mhlo::promoteType(rewriter, v, outType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t posDim = toPositiveDim(dim, outType.getRank());
|
||||||
rewriter.replaceOpWithNewOp<mhlo::ConcatenateOp>(
|
rewriter.replaceOpWithNewOp<mhlo::ConcatenateOp>(
|
||||||
op, ValueRange(builtinTensors), static_cast<uint64_t>(dim));
|
op, ValueRange(builtinTensors), posDim);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// AtenNumelOp
|
||||||
|
namespace {
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
|
||||||
|
AtenNumelOp op,
|
||||||
|
OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter& rewriter) const {
|
||||||
|
auto self = adaptor.self();
|
||||||
|
auto selfTy = self.getType().dyn_cast<RankedTensorType>();
|
||||||
|
size_t rank = selfTy.getRank();
|
||||||
|
|
||||||
|
Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
Value numel =
|
||||||
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(intType, 1));
|
||||||
|
for (size_t d = 0 ; d < rank; ++ d) {
|
||||||
|
Value dimSize = rewriter.create<arith::IndexCastOp>(
|
||||||
|
loc, intType, rewriter.create<tensor::DimOp>(loc, self, d));
|
||||||
|
numel = rewriter.create<arith::MulIOp>(loc, numel, dimSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto outTy = getTypeConverter()->convertType(op.getType());
|
||||||
|
if (outTy != numel.getType()) {
|
||||||
|
rewriter.replaceOpWithNewOp<arith::ExtSIOp>(
|
||||||
|
op, outTy, numel);
|
||||||
|
} else {
|
||||||
|
rewriter.replaceOp(op, numel);
|
||||||
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -1067,5 +1100,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
||||||
|
|
||||||
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
|
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
|
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenNumelOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,8 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
|
||||||
DEPENDS
|
DEPENDS
|
||||||
MhloDialect
|
MhloDialect
|
||||||
ChloDialect
|
ChloDialect
|
||||||
|
MhloToLinalg
|
||||||
|
MLIRMhloPassIncGen
|
||||||
TorchMLIRConversionPassIncGen
|
TorchMLIRConversionPassIncGen
|
||||||
|
|
||||||
LINK_COMPONENTS
|
LINK_COMPONENTS
|
||||||
|
@ -24,6 +26,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MhloDialect
|
MhloDialect
|
||||||
ChloDialect
|
ChloDialect
|
||||||
|
MhloToLinalg
|
||||||
TorchMLIRTorchDialect
|
TorchMLIRTorchDialect
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -490,6 +490,9 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
||||||
if (!matchPattern(op.dim(), m_TorchConstantIntList(inputDims))) {
|
if (!matchPattern(op.dim(), m_TorchConstantIntList(inputDims))) {
|
||||||
return rewriter.notifyMatchFailure(op, "non-int dim list unsupported");
|
return rewriter.notifyMatchFailure(op, "non-int dim list unsupported");
|
||||||
}
|
}
|
||||||
|
if (inputDims.size() == 0) {
|
||||||
|
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
|
||||||
|
}
|
||||||
|
|
||||||
for (auto d : inputDims) {
|
for (auto d : inputDims) {
|
||||||
d = toPositiveDim(d, inputTy.getRank());
|
d = toPositiveDim(d, inputTy.getRank());
|
||||||
|
|
|
@ -256,6 +256,14 @@ public:
|
||||||
numel = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
|
numel = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
|
||||||
numel);
|
numel);
|
||||||
|
|
||||||
|
if (dimSizes.size() == 0) {
|
||||||
|
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
||||||
|
op,
|
||||||
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
|
op.getType()),
|
||||||
|
adaptor.self());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
||||||
Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
|
Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
|
||||||
loc, mhloShape.getType(), numel, mhloShape);
|
loc, mhloShape.getType(), numel, mhloShape);
|
||||||
|
@ -310,6 +318,11 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
|
||||||
if (dSize != 1)
|
if (dSize != 1)
|
||||||
dims.push_back(r);
|
dims.push_back(r);
|
||||||
}
|
}
|
||||||
|
if (dims.size() == 0) {
|
||||||
|
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
||||||
|
op, getTypeConverter()->convertType(op.getType()), self);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
|
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
|
||||||
if (failed(newDimSizesInfo))
|
if (failed(newDimSizesInfo))
|
||||||
|
@ -354,6 +367,11 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
|
||||||
SmallVector<int64_t, 4> dims(rank);
|
SmallVector<int64_t, 4> dims(rank);
|
||||||
std::iota(dims.begin(), dims.end(), 0);
|
std::iota(dims.begin(), dims.end(), 0);
|
||||||
dims.erase(dims.begin() + dim);
|
dims.erase(dims.begin() + dim);
|
||||||
|
if (dims.size() == 0) {
|
||||||
|
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
||||||
|
op, getTypeConverter()->convertType(op.getType()), self);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
|
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
|
||||||
if (failed(newDimSizesInfo))
|
if (failed(newDimSizesInfo))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
|
|
@ -202,7 +202,6 @@ def compile(model: torch.nn.Module,
|
||||||
scripted = torch.jit.trace(model, tuple(example_args_for_trace))
|
scripted = torch.jit.trace(model, tuple(example_args_for_trace))
|
||||||
else:
|
else:
|
||||||
scripted = torch.jit.script(model)
|
scripted = torch.jit.script(model)
|
||||||
|
|
||||||
# Convert all concrete inputs to TensorPlaceholder's, for consistency.
|
# Convert all concrete inputs to TensorPlaceholder's, for consistency.
|
||||||
arg_placeholders = []
|
arg_placeholders = []
|
||||||
for arg in example_args:
|
for arg in example_args:
|
||||||
|
@ -240,7 +239,6 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
||||||
""") from None
|
""") from None
|
||||||
finally:
|
finally:
|
||||||
sys.stderr = original_stderr
|
sys.stderr = original_stderr
|
||||||
|
|
||||||
if output_type == OutputType.RAW:
|
if output_type == OutputType.RAW:
|
||||||
return mb.module
|
return mb.module
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
import abc
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir.ir import Module
|
||||||
|
|
||||||
|
# A type shared between the result of `MhloBackend.compile` and the
|
||||||
|
# input to `MhloBackend.load`. Each backend will likely have a
|
||||||
|
# different definition of this type.
|
||||||
|
CompiledArtifact = TypeVar('CompiledArtifact')
|
||||||
|
|
||||||
|
# A wrapper around a backend-specific loaded program representation
|
||||||
|
# that uniformly translates the `x.method(...)` interface expected of
|
||||||
|
# Torch modules into appropriate lower-level operations.
|
||||||
|
Invoker = TypeVar('Invoker')
|
||||||
|
|
||||||
|
|
||||||
|
class MhloBackend(abc.ABC):
|
||||||
|
"""The interface to an MHLO backend.
|
||||||
|
|
||||||
|
Backends are recommended to raise meaningful exceptions in case of error,
|
||||||
|
ideally with easy reproduction instructions.
|
||||||
|
"""
|
||||||
|
@abc.abstractmethod
|
||||||
|
def compile(self, module: Module) -> CompiledArtifact:
|
||||||
|
"""Compile the provided MLIR module into a compiled artifact.
|
||||||
|
|
||||||
|
The module adheres to the MHLO backend contract
|
||||||
|
(see the VerifyMhloBackendContract pass).
|
||||||
|
|
||||||
|
The compiled artifact can be any type, but must be correctly
|
||||||
|
interpreted by the `load` method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def load(self, artifact: CompiledArtifact) -> Invoker:
|
||||||
|
"""Load the compiled artifact into a uniformly invokable form.
|
||||||
|
|
||||||
|
The compiled artifact is the result of a previous call to `compile`.
|
||||||
|
|
||||||
|
See the description of `Invoker` for the requirements on the returned
|
||||||
|
type.
|
||||||
|
"""
|
|
@ -0,0 +1,45 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
from torch_mlir.ir import *
|
||||||
|
from torch_mlir.passmanager import *
|
||||||
|
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||||
|
|
||||||
|
from .abc import MhloBackend
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LinalgOnTensorsMhloBackend",
|
||||||
|
]
|
||||||
|
|
||||||
|
class LinalgOnTensorsMhloBackend(MhloBackend):
|
||||||
|
"""Main entry-point for the linalg-on-tensors based MHLO backend.
|
||||||
|
|
||||||
|
This currently uses the linalg-on-tensors RefBackend for actual execution.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.refbackend = RefBackendLinalgOnTensorsBackend()
|
||||||
|
|
||||||
|
def compile(self, imported_module: Module):
|
||||||
|
"""Compiles an imported module that satisfied the MHLO backend contract.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
imported_module: The MLIR module consisting of funcs in the MHLO
|
||||||
|
dialect.
|
||||||
|
Returns:
|
||||||
|
An opaque, backend specific compiled artifact object that can be
|
||||||
|
passed to `load`.
|
||||||
|
"""
|
||||||
|
run_pipeline_with_repro_report(
|
||||||
|
imported_module,
|
||||||
|
"func.func(hlo-legalize-to-linalg)",
|
||||||
|
"Lowering MLIR-HLO to Linalg-on-Tensors")
|
||||||
|
return self.refbackend.compile(imported_module)
|
||||||
|
|
||||||
|
def load(self, module):
|
||||||
|
"""Loads a compiled artifact into the runtime."""
|
||||||
|
return self.refbackend.load(module)
|
|
@ -7,5 +7,6 @@ from .lazy_tensor_core import LazyTensorCoreTestConfig
|
||||||
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
|
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
|
||||||
from .native_torch import NativeTorchTestConfig
|
from .native_torch import NativeTorchTestConfig
|
||||||
from .torchscript import TorchScriptTestConfig
|
from .torchscript import TorchScriptTestConfig
|
||||||
|
from .mhlo_backend import MhloBackendTestConfig
|
||||||
from .tosa_backend import TosaBackendTestConfig
|
from .tosa_backend import TosaBackendTestConfig
|
||||||
from .eager_mode import EagerModeTestConfig
|
from .eager_mode import EagerModeTestConfig
|
||||||
|
|
|
@ -0,0 +1,52 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.mhlo_backends.abc import MhloBackend
|
||||||
|
from torch_mlir_e2e_test.torchscript.framework import (
|
||||||
|
TestConfig,
|
||||||
|
Trace,
|
||||||
|
TraceItem
|
||||||
|
)
|
||||||
|
from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders
|
||||||
|
from .utils import (
|
||||||
|
recursively_convert_to_numpy,
|
||||||
|
recursively_convert_from_numpy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MhloBackendTestConfig(TestConfig):
|
||||||
|
"""Base class for TestConfig's that are implemented with linalg-on-tensors.
|
||||||
|
|
||||||
|
This class handles all the common lowering that torch-mlir does before
|
||||||
|
reaching the linalg-on-tensors abstraction level.
|
||||||
|
"""
|
||||||
|
def __init__(self, backend: MhloBackend):
|
||||||
|
super().__init__()
|
||||||
|
self.backend = backend
|
||||||
|
|
||||||
|
def compile(self, program: torch.nn.Module) -> Any:
|
||||||
|
example_args = convert_annotations_to_placeholders(program.forward)
|
||||||
|
module = torch_mlir.compile(
|
||||||
|
program, example_args, output_type="mhlo")
|
||||||
|
|
||||||
|
return self.backend.compile(module)
|
||||||
|
|
||||||
|
def run(self, artifact: Any, trace: Trace) -> Trace:
|
||||||
|
backend_module = self.backend.load(artifact)
|
||||||
|
result: Trace = []
|
||||||
|
for item in trace:
|
||||||
|
numpy_inputs = recursively_convert_to_numpy(item.inputs)
|
||||||
|
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
||||||
|
output = recursively_convert_from_numpy(outputs)
|
||||||
|
result.append(
|
||||||
|
TraceItem(symbol=item.symbol,
|
||||||
|
inputs=item.inputs,
|
||||||
|
output=output))
|
||||||
|
return result
|
Loading…
Reference in New Issue