mirror of https://github.com/llvm/torch-mlir
e2e: Enable generate-runtime-verification pass (#3615)
This adds the `generate-runtime-verification` pass into the linalg refbackend, and moves all tests that now abort at runtime into the crash set, sorted by their respective errors. I have fixed on set of errors found that way, which are mismatches between the static dimensions we cast to and the actual dynamic dimensions. This was caused by wrong annotations on the test cases, like in https://github.com/llvm/torch-mlir/pull/3615/files#diff-48bfbf41fcad5fa01b49197d251114f84a2b8de4f1d87ab938a061aedd1419b1R1931pull/3629/head
parent
0314188dbe
commit
334633b738
|
@ -43,9 +43,11 @@ from .xfail_sets import (
|
|||
LINALG_XFAIL_SET,
|
||||
LINALG_CRASHING_SET,
|
||||
MAKE_FX_TOSA_PASS_SET,
|
||||
MAKE_FX_TOSA_CRASHING_SET,
|
||||
STABLEHLO_PASS_SET,
|
||||
STABLEHLO_CRASHING_SET,
|
||||
TOSA_PASS_SET,
|
||||
TOSA_CRASHING_SET,
|
||||
LTC_XFAIL_SET,
|
||||
LTC_CRASHING_SET,
|
||||
TORCHDYNAMO_XFAIL_SET,
|
||||
|
@ -161,11 +163,11 @@ def main():
|
|||
elif args.config == "tosa":
|
||||
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
|
||||
xfail_set = all_test_unique_names - TOSA_PASS_SET
|
||||
crashing_set = set()
|
||||
crashing_set = TOSA_CRASHING_SET
|
||||
elif args.config == "make_fx_tosa":
|
||||
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend(), use_make_fx=True)
|
||||
xfail_set = all_test_unique_names - MAKE_FX_TOSA_PASS_SET
|
||||
crashing_set = set()
|
||||
crashing_set = MAKE_FX_TOSA_CRASHING_SET
|
||||
elif args.config == "native_torch":
|
||||
config = NativeTorchTestConfig()
|
||||
xfail_set = set()
|
||||
|
@ -191,7 +193,10 @@ def main():
|
|||
xfail_set = FX_IMPORTER_TOSA_XFAIL_SET
|
||||
crashing_set = set()
|
||||
elif args.config == "torchdynamo":
|
||||
config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
|
||||
# TODO: Enanble runtime verification and extend crashing set.
|
||||
config = TorchDynamoTestConfig(
|
||||
RefBackendLinalgOnTensorsBackend(generate_runtime_verification=False)
|
||||
)
|
||||
xfail_set = TORCHDYNAMO_XFAIL_SET
|
||||
crashing_set = TORCHDYNAMO_CRASHING_SET
|
||||
elif args.config == "onnx":
|
||||
|
|
|
@ -34,6 +34,31 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
|||
}
|
||||
|
||||
LINALG_CRASHING_SET = {
|
||||
# Runtime op verification: Out of bounds access
|
||||
"AtenDiagEmbedNegOffsetDiag_basic",
|
||||
"AtenDiagEmbedNonDefault4DDiag_basic",
|
||||
"AtenDiagEmbedOffsetDiag_basic",
|
||||
"AtenDiagEmbedRevDimDiag_basic",
|
||||
"AtenEmbeddingBagStaticModule_basic",
|
||||
"AtenEmbeddingBagSumExample_basic",
|
||||
"Aten_EmbeddingBagExample_basic",
|
||||
# Runtime op verification: subview is out-of-bounds of the base memref
|
||||
"Conv_Transpose1dModule_basic",
|
||||
"Conv_Transpose1dStaticModule_basic",
|
||||
"Conv_Transpose2dModule_basic",
|
||||
"Conv_Transpose2dStaticModule_basic",
|
||||
"Conv_Transpose3dModule_basic",
|
||||
"Conv_Transpose3dStaticModule_basic",
|
||||
"ConvolutionModule2DTransposeStridedStatic_basic",
|
||||
"ConvolutionModule2DTransposeStrided_basic",
|
||||
"GridSamplerBasic1_basic",
|
||||
"GridSamplerBasic2_basic",
|
||||
"GridSamplerBasic3_basic",
|
||||
"GridSamplerBasic4_basic",
|
||||
# Runtime op verification: stride mismatch in memref.cast
|
||||
"ReduceAllDimEmpty_basic",
|
||||
"TraceUnsignedIntModule_empty",
|
||||
"TraceModule_empty",
|
||||
# Crashes due to copy to a smaller destination buffer than the source buffer.
|
||||
"SliceCopyStartGreaterThanDimSize_Module_basic",
|
||||
}
|
||||
|
@ -476,8 +501,11 @@ FX_IMPORTER_XFAIL_SET = {
|
|||
"WeightNormInterfaceModule_basic",
|
||||
}
|
||||
|
||||
FX_IMPORTER_CRASHING_SET = {
|
||||
FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
|
||||
"HBC_basic",
|
||||
# Runtime op verification: out-of-bounds access
|
||||
"_SoftmaxModule_basic",
|
||||
"UpSampleNearest2dDynamicFactor_basic",
|
||||
}
|
||||
|
||||
FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||
|
@ -899,7 +927,6 @@ STABLEHLO_PASS_SET = {
|
|||
"AtenIntBoolOpModule_basic",
|
||||
"AtenIntTensorByteDtypeModule_basic",
|
||||
"AtenIntTensorCharDtypeModule_basic",
|
||||
"AtenItemFpOpModule_basic",
|
||||
"AtenItemIntOpModule_basic",
|
||||
"AtenMmFloatTypes_basic",
|
||||
"AtenMmIntTypes_basic",
|
||||
|
@ -1522,6 +1549,16 @@ STABLEHLO_CRASHING_SET = {
|
|||
"IndexPutWithNoneAndBroadcastModule_basic",
|
||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
"ReduceMinAlongDimUnsignedInt_basic",
|
||||
# LLVM ERROR: Failed to infer result type(s)
|
||||
"ElementwiseClampMinTensorFloatModule_basic",
|
||||
"ElementwiseClampMinTensorIntModule_basic",
|
||||
"ElementwiseClampTensorFloatModule_basic",
|
||||
"ElementwiseClampTensorIntModule_basic",
|
||||
}
|
||||
|
||||
TOSA_CRASHING_SET = {
|
||||
# Runtime op verification: Out of bounds access
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
}
|
||||
|
||||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
|
@ -2010,6 +2047,15 @@ TOSA_PASS_SET = {
|
|||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
||||
}
|
||||
|
||||
MAKE_FX_TOSA_CRASHING_SET = TOSA_CRASHING_SET | {
|
||||
# Runtime op verification: static result dims in reassoc group do not divide src dim evenly
|
||||
"FlattenDynamicModule_basic",
|
||||
"ReshapeDynamicModule_basic",
|
||||
"ViewFlattenAndExpandModule_basic",
|
||||
"ViewSizeDimLedAndFollowedByExpandedOnesModule_basic",
|
||||
"ViewSizeDimLedByExpandedOnesModule_basic",
|
||||
}
|
||||
|
||||
MAKE_FX_TOSA_PASS_SET = (
|
||||
TOSA_PASS_SET
|
||||
| {
|
||||
|
@ -2821,7 +2867,7 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"):
|
|||
}
|
||||
|
||||
|
||||
ONNX_CRASHING_SET = {
|
||||
ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
||||
"FakeQuantizePerTensorAffineModule_basic",
|
||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||
"ElementwisePreluModule_basic",
|
||||
|
@ -2840,6 +2886,8 @@ ONNX_CRASHING_SET = {
|
|||
"StdCorrectionEmptyDimModule_basic",
|
||||
"VarCorrectionEmptyDimModule_basic",
|
||||
"VarDimEmptyDimModule_basic",
|
||||
# Runtime op verification: rank mismatch in memref.cast
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
}
|
||||
|
||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||
|
|
|
@ -134,82 +134,84 @@ class RefBackendInvoker:
|
|||
return invoke
|
||||
|
||||
|
||||
LOWERING_PIPELINE = (
|
||||
"builtin.module("
|
||||
+ ",".join(
|
||||
[
|
||||
# Apply some optimizations. It would be great if MLIR had more useful
|
||||
# optimizations that worked out of the box here.
|
||||
# Note: When measured, this doesn't seem to actually help that much
|
||||
# for the linalg-on-tensors backend.
|
||||
# This is likely because if things are naturally fusable we usually already
|
||||
# emit things in that form from the high level (e.g. single linalg-generic).
|
||||
# Other backends are likely to benefit more.
|
||||
"func.func(linalg-generalize-named-ops)",
|
||||
"func.func(linalg-fuse-elementwise-ops)",
|
||||
"convert-shape-to-std",
|
||||
# MLIR Sparsifier mini-pipeline. Note that this is the bare minimum
|
||||
# to ensure operations on sparse tensors are lowered to loops.
|
||||
"sparse-assembler{direct-out}",
|
||||
"sparsification-and-bufferization",
|
||||
"sparse-storage-specifier-to-llvm",
|
||||
# Buffer deallocation pass does not know how to handle realloc.
|
||||
"func.func(expand-realloc)",
|
||||
# Generalize pad and concat after sparse compiler, as they are handled
|
||||
# differently when the operations involve sparse operand.
|
||||
"func.func(refback-generalize-tensor-pad)",
|
||||
"func.func(refback-generalize-tensor-concat)",
|
||||
# Bufferize.
|
||||
"func.func(tm-tensor-bufferize)",
|
||||
"one-shot-bufferize{copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map}",
|
||||
"refback-mlprogram-bufferize",
|
||||
"func.func(finalizing-bufferize)",
|
||||
"func.func(buffer-deallocation)",
|
||||
# Buffer-deallocation does not work with the inlined code generated
|
||||
# by sparse tensor dialect.
|
||||
"inline", # inline sparse helper methods where useful
|
||||
# Munge to make it ExecutionEngine compatible.
|
||||
# Specifically, we rewrite calling convention boundaries to be in terms
|
||||
# of unranked memref, and we rewrite the return to actually be a
|
||||
# callback that consumes the return (the final munged function always
|
||||
# returns void at the C level -- we get the return value by providing the
|
||||
# callback).
|
||||
"refback-munge-calling-conventions",
|
||||
# Insert global variable and instruction sequence for getting the next
|
||||
# global seed used in stateful rng.
|
||||
# Lower to LLVM
|
||||
"func.func(tm-tensor-to-loops)",
|
||||
"func.func(refback-munge-memref-copy)",
|
||||
"func.func(convert-linalg-to-loops)",
|
||||
"func.func(lower-affine)",
|
||||
"convert-scf-to-cf",
|
||||
"func.func(refback-expand-ops-for-llvm)",
|
||||
"func.func(arith-expand)",
|
||||
"func.func(convert-math-to-llvm)",
|
||||
# Handle some complex mlir::math ops (e.g. atan2)
|
||||
"convert-math-to-libm",
|
||||
"expand-strided-metadata",
|
||||
"finalize-memref-to-llvm",
|
||||
"lower-affine",
|
||||
"convert-bufferization-to-memref",
|
||||
"finalize-memref-to-llvm",
|
||||
"func.func(convert-arith-to-llvm)",
|
||||
"convert-vector-to-llvm",
|
||||
"convert-func-to-llvm",
|
||||
"convert-cf-to-llvm",
|
||||
"convert-complex-to-llvm",
|
||||
"reconcile-unrealized-casts",
|
||||
]
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
def lowering_pipeline(generate_runtime_verification: bool):
|
||||
passes = [
|
||||
# Apply some optimizations. It would be great if MLIR had more useful
|
||||
# optimizations that worked out of the box here.
|
||||
# Note: When measured, this doesn't seem to actually help that much
|
||||
# for the linalg-on-tensors backend.
|
||||
# This is likely because if things are naturally fusable we usually already
|
||||
# emit things in that form from the high level (e.g. single linalg-generic).
|
||||
# Other backends are likely to benefit more.
|
||||
"func.func(linalg-generalize-named-ops)",
|
||||
"func.func(linalg-fuse-elementwise-ops)",
|
||||
"convert-shape-to-std",
|
||||
# MLIR Sparsifier mini-pipeline. Note that this is the bare minimum
|
||||
# to ensure operations on sparse tensors are lowered to loops.
|
||||
"sparse-assembler{direct-out}",
|
||||
"sparsification-and-bufferization",
|
||||
"sparse-storage-specifier-to-llvm",
|
||||
# Buffer deallocation pass does not know how to handle realloc.
|
||||
"func.func(expand-realloc)",
|
||||
# Generalize pad and concat after sparse compiler, as they are handled
|
||||
# differently when the operations involve sparse operand.
|
||||
"func.func(refback-generalize-tensor-pad)",
|
||||
"func.func(refback-generalize-tensor-concat)",
|
||||
# Bufferize.
|
||||
"func.func(tm-tensor-bufferize)",
|
||||
"one-shot-bufferize{copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map}",
|
||||
"refback-mlprogram-bufferize",
|
||||
"func.func(finalizing-bufferize)",
|
||||
"func.func(buffer-deallocation)",
|
||||
# Buffer-deallocation does not work with the inlined code generated
|
||||
# by sparse tensor dialect.
|
||||
"inline", # inline sparse helper methods where useful
|
||||
# Munge to make it ExecutionEngine compatible.
|
||||
# Specifically, we rewrite calling convention boundaries to be in terms
|
||||
# of unranked memref, and we rewrite the return to actually be a
|
||||
# callback that consumes the return (the final munged function always
|
||||
# returns void at the C level -- we get the return value by providing the
|
||||
# callback).
|
||||
"refback-munge-calling-conventions",
|
||||
# Insert global variable and instruction sequence for getting the next
|
||||
# global seed used in stateful rng.
|
||||
# Lower to LLVM
|
||||
"func.func(tm-tensor-to-loops)",
|
||||
"func.func(refback-munge-memref-copy)",
|
||||
"func.func(convert-linalg-to-loops)",
|
||||
"func.func(lower-affine)",
|
||||
"convert-scf-to-cf",
|
||||
]
|
||||
if generate_runtime_verification:
|
||||
passes += ["generate-runtime-verification"]
|
||||
passes += [
|
||||
"func.func(refback-expand-ops-for-llvm)",
|
||||
"func.func(arith-expand)",
|
||||
"func.func(convert-math-to-llvm)",
|
||||
# Handle some complex mlir::math ops (e.g. atan2)
|
||||
"convert-math-to-libm",
|
||||
"expand-strided-metadata",
|
||||
"finalize-memref-to-llvm",
|
||||
"lower-affine",
|
||||
"convert-bufferization-to-memref",
|
||||
"finalize-memref-to-llvm",
|
||||
"func.func(convert-arith-to-llvm)",
|
||||
"convert-vector-to-llvm",
|
||||
"convert-func-to-llvm",
|
||||
"convert-cf-to-llvm",
|
||||
"convert-complex-to-llvm",
|
||||
"reconcile-unrealized-casts",
|
||||
]
|
||||
|
||||
return "builtin.module(" + ",".join(passes) + ")"
|
||||
|
||||
|
||||
class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend):
|
||||
"""Main entry-point for the reference backend."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, generate_runtime_verification: bool = True):
|
||||
super().__init__()
|
||||
self.generate_runtime_verification = generate_runtime_verification
|
||||
|
||||
def compile(self, imported_module: Module):
|
||||
"""Compiles an imported module, with a flat list of functions.
|
||||
|
@ -226,7 +228,7 @@ class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend):
|
|||
"""
|
||||
run_pipeline_with_repro_report(
|
||||
imported_module,
|
||||
LOWERING_PIPELINE,
|
||||
lowering_pipeline(self.generate_runtime_verification),
|
||||
"Lowering Linalg-on-Tensors IR to LLVM with RefBackend",
|
||||
enable_ir_printing=False,
|
||||
)
|
||||
|
|
|
@ -37,7 +37,10 @@ class LinalgOnTensorsStablehloBackend(StablehloBackend):
|
|||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.refbackend = RefBackendLinalgOnTensorsBackend()
|
||||
# TOOD: Enable runtime verification and fix found bugs.
|
||||
self.refbackend = RefBackendLinalgOnTensorsBackend(
|
||||
generate_runtime_verification=False
|
||||
)
|
||||
|
||||
def compile(self, imported_module: Module):
|
||||
"""Compiles an imported module that satisfied the Stablehlo backend contract.
|
||||
|
|
|
@ -1928,7 +1928,7 @@ class EmptyStridedModule(torch.nn.Module):
|
|||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 3, 4], torch.float32, True),
|
||||
([4, 3, 4], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
|
|
|
@ -1658,8 +1658,8 @@ class ElementwiseClampTensorFloatModule(torch.nn.Module):
|
|||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
([1], torch.float32, True),
|
||||
([1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x, min, max):
|
||||
|
@ -1688,8 +1688,8 @@ class ElementwiseClampTensorIntModule(torch.nn.Module):
|
|||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
([1], torch.int64, True),
|
||||
([1], torch.int64, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x, min, max):
|
||||
|
@ -1741,7 +1741,7 @@ class ElementwiseClampMinTensorFloatModule(torch.nn.Module):
|
|||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
([1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x, min):
|
||||
|
@ -1765,7 +1765,7 @@ class ElementwiseClampMinTensorIntModule(torch.nn.Module):
|
|||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
([1], torch.int64, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x, min):
|
||||
|
|
|
@ -542,7 +542,7 @@ class AtenItemFpOpModule(torch.nn.Module):
|
|||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float, True),
|
||||
([1], torch.float, True),
|
||||
]
|
||||
)
|
||||
def forward(self, val):
|
||||
|
|
|
@ -175,7 +175,9 @@ def sparse_jit(f, *args, **kwargs):
|
|||
enable_ir_printing=False,
|
||||
)
|
||||
# Compile with reference Linalg backend.
|
||||
backend = RefBackendLinalgOnTensorsBackend()
|
||||
# TODO: runtime verification currently fails with 'rank mismatch' on
|
||||
# memref.cast. Need to fix the IR first.
|
||||
backend = RefBackendLinalgOnTensorsBackend(generate_runtime_verification=False)
|
||||
compiled = backend.compile(module)
|
||||
invoker = backend.load(compiled)
|
||||
xargs = []
|
||||
|
|
Loading…
Reference in New Issue