add ci tests (#754)

pull/905/head
Maksim Levental 2022-05-25 14:59:59 -05:00 committed by GitHub
parent 24e04d5729
commit cec5aeedb0
12 changed files with 758 additions and 517 deletions

View File

@ -44,6 +44,11 @@ jobs:
cd $GITHUB_WORKSPACE
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
python -m e2e_testing.torchscript.main --config=refbackend -v
- name: EagerMode - TorchScript end-to-end tests
run: |
cd $GITHUB_WORKSPACE
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
python -m e2e_testing.torchscript.main --config=eager_mode -v
- name: TOSA backend - TorchScript end-to-end tests
run: |
cd $GITHUB_WORKSPACE

View File

@ -14,12 +14,11 @@ from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
EAGER_MODE_XFAIL_SET = REFBACKEND_XFAIL_SET.union({
# These fail because of an upstream pytorch bug; more information at the following issue
# https://github.com/pytorch/pytorch/issues/74400
"ElementwiseMulScalarModule_basic",
"ElementwiseSubScalarIntModule_basic",
})
EAGER_MODE_XFAIL_SET = {
# RefBackend fails
"TableBatchEmbeddingModule_basic",
"QuantizedMLP_basic"
}
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.

View File

@ -1,99 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
import torch
from framework import run_test
from torch_mlir.eager_mode.torch_mlir_dispatch import (
annotate_args_kwargs,
normalize_args_kwargs,
build_script_function,
)
# CHECK: Torch Tensor (shape=(1, 3, 32, 32), dtype=torch.float32)
# CHECK: Torch Tensor (shape=(1, 3, 32, 32), dtype=torch.float32)
# CHECK: Torch Tensor (shape=(1, 3, 32, 32), dtype=torch.float32)
# -----
# CHECK: PASS - simple
@run_test
def simple():
target = torch.ops.aten.addmm.default
A = torch.randn(1, 3, 32, 32)
B = torch.randn(1, 3, 32, 32)
C = torch.randn(1, 3, 32, 32)
args = (A, B, C)
kwargs = dict(beta=1, alpha=1)
new_args, new_kwargs = normalize_args_kwargs(target.overloadpacket, args, kwargs)
script_fun = build_script_function(target._schema, new_args, new_kwargs)
annotations, *_ = annotate_args_kwargs(script_fun, new_args, new_kwargs)
for annot in annotations:
print(annot)
# CHECK: Torch Tensor (shape=(-1, 3, 32, 32), dtype=torch.float32)
# CHECK: Torch Tensor (shape=(-1, 3, 32, 32), dtype=torch.float32)
# CHECK: Torch Tensor (shape=(-1, 3, 32, 32), dtype=torch.float32)
# -----
# CHECK: PASS - handle_zero_dim
@run_test
def handle_zero_dim():
target = torch.ops.aten.addmm.default
A = torch.randn(0, 3, 32, 32)
B = torch.randn(0, 3, 32, 32)
C = torch.randn(0, 3, 32, 32)
args = (A, B, C)
kwargs = dict(beta=1, alpha=1)
new_args, new_kwargs = normalize_args_kwargs(target.overloadpacket, args, kwargs)
script_fun = build_script_function(target._schema, new_args, new_kwargs)
annotations, *_ = annotate_args_kwargs(script_fun, new_args, new_kwargs)
for annot in annotations:
print(annot)
# CHECK: Torch Tensor (shape=(2, 5, 2, 3), dtype=torch.float32)
# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
# CHECK: Torch Tensor (shape=(2, 5, 2, 3), dtype=torch.float32)
# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
# -----
# CHECK: PASS - correctly_order_kwargs
@run_test
def correctly_order_kwargs():
target = torch.ops.aten.native_batch_norm.out
input = torch.randn(2, 5, 2, 3)
weight = torch.randn(5)
bias = torch.randn(5)
running_mean = torch.randn(5)
running_var = torch.randn(5)
args = (input, weight, bias, running_mean, running_var)
out = torch.empty_like(input)
save_mean = torch.empty_like(running_mean)
save_invstd = torch.empty_like(running_var)
kwargs = dict(
training=False,
momentum=0.1,
eps=0.0001,
out=out,
save_mean=save_mean,
save_invstd=save_invstd,
)
new_args, new_kwargs = normalize_args_kwargs(target.overloadpacket, args, kwargs)
script_fun = build_script_function(target._schema, new_args, new_kwargs)
annotations, *_ = annotate_args_kwargs(script_fun, new_args, new_kwargs)
for annot in annotations:
print(annot)

View File

@ -9,7 +9,7 @@
import torch
from framework import run_test
from torch_mlir.eager_mode.torch_mlir_dispatch import build_script_function
from torch_mlir.eager_mode.ir_building import build_ts_script_function
# CHECK: graph(%[[A1:.*]] : Tensor,
@ -24,13 +24,15 @@ from torch_mlir.eager_mode.torch_mlir_dispatch import build_script_function
@run_test
def simple():
target = torch.ops.aten.addmm.default
A = torch.randn(1, 3, 32, 32)
B = torch.randn(1, 3, 32, 32)
C = torch.randn(1, 3, 32, 32)
args = (A, B, C)
kwargs = dict(beta=1, alpha=1)
kwargs = dict(
input=torch.randn(1, 3, 32, 32),
mat1=torch.randn(1, 3, 32, 32),
mat2=torch.randn(1, 3, 32, 32),
beta=1,
alpha=1,
)
script_fun = build_script_function(target._schema, args, kwargs)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)
@ -50,11 +52,10 @@ def simple():
@run_test
def handle_optional_tensor_input():
target = torch.ops.aten.convolution.default
input = torch.randn(1, 3, 32, 32)
weight = torch.randn(3, 3, 3, 3)
bias = torch.randn(3)
args = (input, weight, bias)
kwargs = dict(
input=torch.randn(1, 3, 32, 32),
weight=torch.randn(3, 3, 3, 3),
bias=torch.randn(3),
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
@ -62,20 +63,19 @@ def handle_optional_tensor_input():
output_padding=[0, 0],
groups=1,
)
script_fun = build_script_function(target._schema, args, kwargs)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)
# CHECK: FAIL - fail_not_enough_args
# CHECK: Errors: tuple index out of range
# CHECK: Errors: 'groups'
@run_test
def fail_not_enough_args():
target = torch.ops.aten.convolution.default
input = torch.randn(1, 3, 32, 32)
weight = torch.randn(3, 3, 3, 3)
bias = torch.randn(3)
args = (input, weight, bias)
kwargs = dict(
input=torch.randn(1, 3, 32, 32),
weight=torch.randn(3, 3, 3, 3),
bias=torch.randn(3),
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
@ -83,41 +83,41 @@ def fail_not_enough_args():
output_padding=[0, 0],
# Missing groups=1,
)
build_script_function(target._schema, args, kwargs)
build_ts_script_function(target._schema, kwargs)
# CHECK: PASS - simple_args_or_kwargs
# CHECK: graph(%input : Tensor,
# CHECK: %weight : Tensor,
# CHECK: %bias : Tensor):
# CHECK: %4 : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %5 : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %6 : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %7 : bool = prim::Constant[value=0]()
# CHECK: %8 : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %9 : int = prim::Constant[value=1]()
# CHECK: %0 : Tensor = aten::convolution(%input, %weight, %bias, %4, %5, %6, %7, %8, %9)
# CHECK: return (%0)
# -----
# CHECK: PASS - simple_kwargs
@run_test
def simple_args_or_kwargs():
def simple_kwargs():
target = torch.ops.aten.convolution.default
input = torch.randn(1, 3, 32, 32)
weight = torch.randn(3, 3, 3, 3)
bias = torch.randn(3)
stride = [1, 1]
padding = [0, 0]
dilation = [1, 1]
transposed = False
output_padding = [0, 0]
groups = 1
script_fun1 = build_script_function(
script_fun1 = build_ts_script_function(
target._schema,
(input, weight, bias),
dict(
stride=stride,
padding=padding,
dilation=dilation,
transposed=transposed,
output_padding=output_padding,
groups=groups,
input=torch.randn(1, 3, 32, 32),
weight=torch.randn(3, 3, 3, 3),
bias=torch.randn(3),
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
),
)
script_fun2 = build_script_function(
target._schema,
(input, weight, bias, stride, padding, dilation),
dict(transposed=transposed, output_padding=output_padding, groups=groups),
)
assert str(script_fun1.graph) == str(script_fun2.graph)
print(script_fun1.graph)
# CHECK: graph(%[[C2:.*]] : Tensor):
@ -134,15 +134,16 @@ def simple_args_or_kwargs():
def handle_empty_lists():
target = torch.ops.aten.max_pool2d_with_indices.default
# print(target._schema)
args = (torch.randn((1, 3, 32, 32)),)
kwargs = {
"kernel_size": [3, 3],
"stride": [],
"padding": [0, 0],
"dilation": [1, 1],
"ceil_mode": False,
}
script_fun = build_script_function(target._schema, args, kwargs)
input = torch.randn((1, 3, 32, 32))
kwargs = dict(
input=input,
kernel_size=[3, 3],
stride=[],
padding=[0, 0],
dilation=[1, 1],
ceil_mode=False,
)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)
@ -160,15 +161,15 @@ def handle_empty_lists():
def handle_nones():
target = torch.ops.aten.max_pool2d_with_indices.default
# print(target._schema)
args = (torch.randn((1, 3, 32, 32)),)
kwargs = {
"kernel_size": [3, 3],
"stride": None,
"padding": [0, 0],
"dilation": [1, 1],
"ceil_mode": False,
}
script_fun = build_script_function(target._schema, args, kwargs)
kwargs = dict(
input=torch.randn((1, 3, 32, 32)),
kernel_size=[3, 3],
stride=None,
padding=[0, 0],
dilation=[1, 1],
ceil_mode=False,
)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)
@ -188,11 +189,10 @@ def handle_nones():
@run_test
def handle_optional_tensors():
target = torch.ops.aten.convolution.default
input = torch.randn(1, 3, 32, 32)
weight = torch.randn(3, 3, 3, 3)
bias = torch.randn(3)
args = (input, weight, bias)
kwargs = dict(
input=torch.randn(1, 3, 32, 32),
weight=torch.randn(3, 3, 3, 3),
bias=torch.randn(3),
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
@ -200,7 +200,7 @@ def handle_optional_tensors():
output_padding=[0, 0],
groups=1,
)
script_fun = build_script_function(target._schema, args, kwargs)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)
@ -217,12 +217,15 @@ def handle_optional_tensors():
@run_test
def handle_ones_like():
target = torch.ops.aten.ones_like.default
input = torch.randn(1, 3, 32, 32)
args = (input,)
kwargs = dict(
dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
input=torch.randn(1, 3, 32, 32),
dtype=None,
layout=None,
device=None,
pin_memory=None,
memory_format=None,
)
script_fun = build_script_function(target._schema, args, kwargs)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)
@ -241,13 +244,18 @@ def handle_ones_like():
@run_test
def handle_multiple_outputs():
target = torch.ops.aten.native_batch_norm.default
A = torch.randn(1, 3, 32, 32)
B = torch.randn(1, 3, 32, 32)
C = torch.randn(1, 3, 32, 32)
args = (A, B, C, None, None, False, 1.0, 1.0)
kwargs = dict()
kwargs = dict(
input=torch.randn(1, 3, 32, 32),
weight=torch.randn(1, 3, 32, 32),
bias=torch.randn(1, 3, 32, 32),
running_mean=None,
running_var=None,
training=False,
momentum=1.0,
eps=1.0
)
script_fun = build_script_function(target._schema, args, kwargs)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)
@ -256,13 +264,18 @@ def handle_multiple_outputs():
@run_test
def check_legal_name():
target = torch.ops.aten.native_batch_norm.default
A = torch.randn(1, 3, 32, 32)
B = torch.randn(1, 3, 32, 32)
C = torch.randn(1, 3, 32, 32)
args = (A, B, C, None, None, False, 1.0, 1.0)
kwargs = dict()
kwargs = dict(
input=torch.randn(1, 3, 32, 32),
weight=torch.randn(1, 3, 32, 32),
bias=torch.randn(1, 3, 32, 32),
running_mean=None,
running_var=None,
training=False,
momentum=1.0,
eps=1.0
)
script_fun = build_script_function(target._schema, args, kwargs)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.name)
@ -286,24 +299,22 @@ def correctly_order_kwargs():
target = torch.ops.aten.native_batch_norm.out
input = torch.randn(2, 5, 2, 3)
weight = torch.randn(5)
bias = torch.randn(5)
running_mean = torch.randn(5)
running_var = torch.randn(5)
args = (input, weight, bias, running_mean, running_var)
out = torch.empty_like(input)
save_mean = torch.empty_like(running_mean)
save_invstd = torch.empty_like(running_var)
kwargs = dict(
input=torch.randn(2, 5, 2, 3),
weight=torch.randn(5),
bias=torch.randn(5),
running_mean=running_mean,
running_var=running_var,
training=False,
momentum=0.1,
eps=0.0001,
out=out,
save_mean=save_mean,
save_invstd=save_invstd,
out=torch.empty_like(input),
save_mean=torch.empty_like(running_mean),
save_invstd=torch.empty_like(running_var),
)
script_fun = build_script_function(target._schema, args, kwargs)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)

View File

@ -15,8 +15,8 @@ from torch_mlir.eager_mode.torch_mlir_dispatch import normalize_args_kwargs
# CHECK: PASS - should_normalize
@run_test
def should_normalize():
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
args = (torch.randn((1, 3, 32, 32)),)
target = torch.ops.aten.max_pool2d_with_indices.default
input = torch.randn((1, 3, 32, 32))
kwargs = {"kernel_size": [3, 3]}
golden = {
"kernel_size": [3, 3],
@ -28,18 +28,18 @@ def should_normalize():
"ceil_mode": False,
}
new_args, new_kwargs = normalize_args_kwargs(target, args, kwargs)
for arg, new_arg in zip(args, new_args):
assert torch.allclose(arg, new_arg)
new_kwargs = normalize_args_kwargs(target, (input,), kwargs)
assert torch.allclose(new_kwargs["input"], input)
for k, v in new_kwargs.items():
if k == "input": continue
assert v == golden[k]
# CHECK: FAIL - shouldnt_normalize1
# CHECK: Couldn't normalize args and kwargs
# CHECK: Errors: missing a required argument: 'kernel_size'
@run_test
def shouldnt_normalize1():
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
target = torch.ops.aten.max_pool2d_with_indices.default
args = (torch.randn((1, 3, 32, 32)),)
kwargs = {"stride": []}
normalize_args_kwargs(target, args, kwargs)
@ -54,7 +54,7 @@ def shouldnt_normalize1():
# CHECK: XPASS - shouldnt_normalize2
@run_test(XPASS=True)
def shouldnt_normalize2():
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
target = torch.ops.aten.max_pool2d_with_indices.default
args = (torch.randn((1, 3, 32, 32)),)
kwargs = {"kernel_size": []}
normalize_args_kwargs(target, args, kwargs)
@ -63,7 +63,7 @@ def shouldnt_normalize2():
# CHECK: XPASS - shouldnt_normalize3
@run_test(XPASS=True)
def shouldnt_normalize3():
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
target = torch.ops.aten.max_pool2d_with_indices.default
args = (torch.randn((1, 3, 32, 32)),)
kwargs = {"kernel_size": [3, 3], "padding": None}
normalize_args_kwargs(target, args, kwargs)

View File

@ -0,0 +1,3 @@
import os
EAGER_MODE_DEBUG = os.environ.get("EAGER_MODE_DEBUG", 'False').lower() in ('true', '1', 't')
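For reference, a minimal usage sketch (assuming the `torch_mlir.eager_mode` package layout added in this commit): the flag is read once at import time, so the environment variable has to be set before the package is first imported.
import os

os.environ["EAGER_MODE_DEBUG"] = "1"  # any of "true", "1", "t" (case-insensitive)

from torch_mlir.eager_mode import EAGER_MODE_DEBUG

assert EAGER_MODE_DEBUG  # fallback warnings will now include full tracebacks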

View File

@ -22,11 +22,15 @@ to convert the TorchScript function into MLIR using the `torch` dialect.
"""
import abc
from typing import Any, Optional, Iterable
import re
from typing import Any, Optional, Iterable, Dict
from typing import Union
import numpy as np
import torch
from torch.jit import ScriptFunction
import torch._C
import torch.jit
from torch._ops import OpOverload
from torch_mlir import ir
from torch_mlir.dialects.func import FuncOp
@ -202,26 +206,148 @@ def get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
return None
def build_module(jit_function: ScriptFunction, annotations) -> ir.Module:
def is_tensor_type(typ: torch._C.Type):
return typ.isSubtypeOf(torch.TensorType.get()) or (
isinstance(typ, torch.OptionalType)
and typ.getElementType().isSubtypeOf(torch._C.TensorType.get())
)
def is_list_of_tensors_type(typ: torch._C.Type):
return isinstance(typ, torch.ListType) and is_tensor_type(typ.getElementType())
name_mangle_regex = re.compile("[^a-zA-Z0-9]")
def build_ts_script_function(
schema: torch._C.FunctionSchema, kwargs: Dict[str, Any]
) -> torch.jit.ScriptFunction:
"""Build a torch.jit.ScriptFunction that corresponds to the schema.
Constants are inlined for the purposes of invalidating the compile cache when they change.
Parameters
----------
schema: torch._C.FunctionSchema
PyTorch's representation for ops, contains type information needed for inlining constants into the TS graph.
kwargs: Dict
A dictionary with all arguments passed in through __torch_dispatch__ (including int/float/bool params).
Returns
-------
torch.jit.ScriptFunction
Fully specialized (all constants) TS graph whose only arguments are tensors.
"""
# Creates empty TS graph.
graph = torch._C.Graph()
# Creates and inserts node with identifier `schema.name`; NB node has no inputs or outputs at this point.
node = graph.insertNode(graph.create(schema.name, len(schema.returns)))
# Associate graph inputs/outputs with node inputs/outputs.
graph_inputs = []
for arg in schema.arguments:
arg_name = arg.name if arg.name != "self" else "input"
# If arg is a flattened list of tensors, such as in the case of torch.cat
# then add each element of the list to the graph corresponding to arg
# and insert a ListConstruct to function as input to the op.
if is_list_of_tensors_type(arg.type):
inps = []
for kwarg in [
kwarg for kwarg in kwargs if f"{arg_name}_flattened" in kwarg
]:
inp = graph.addInput()
el_typ = arg.type.getElementType()
if isinstance(el_typ, torch.OptionalType):
el_typ = el_typ.getElementType()
inp.setType(el_typ)
inp.setDebugName(kwarg)
inps.append(inp)
graph_inputs.append(kwarg)
list_cons = graph.insertNode(graph.create("prim::ListConstruct", inps))
list_cons.moveBefore(node)
inp = list_cons.output()
inp.setType(torch.ListType.ofTensors())
# If arg is a tensor, then add input to the graph corresponding to arg.
elif is_tensor_type(arg.type) and kwargs[arg_name] is not None:
inp = graph.addInput()
if isinstance(arg.type, torch.OptionalType):
el_typ = arg.type.getElementType()
else:
el_typ = arg.type
inp.setType(el_typ)
inp.setDebugName(arg_name)
graph_inputs.append(arg_name)
# If arg is a constant, inline (at the top of the graph).
else:
val = kwargs[arg_name]
if val == []:
# Some ops have empty list default values for args
# (such as aten::max_pool2d_with_indices with int[2] stride=[]),
# but graph.insertConstant doesn't recognize [] as an empty list IValue.
# This might be an upstream bug but there doesn't seem to be a way to
# build a prim::ListConstruct list that's empty.
val = None
inp = graph.insertConstant(val)
inp.node().moveBefore(node)
node.addInput(inp)
# Reorder graph inputs to match kwargs.
permutes = [
{inp: i for i, inp in enumerate(graph_inputs)}[kwarg]
for kwarg in [kwarg for kwarg in kwargs if kwarg in graph_inputs]
]
graph.permuteInputs(permutes)
if node.hasMultipleOutputs():
for outp in node.outputs():
graph.registerOutput(outp)
else:
graph.registerOutput(node.output())
fn = torch._C._create_function_from_graph(
f"{name_mangle_regex.sub('', str(graph))}", graph
)
return fn
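For illustration, a small usage sketch mirroring the updated handle_ones_like test above: the schema comes from an OpOverload, and every argument, tensors and scalars alike, is passed through the kwargs dictionary.
import torch

target = torch.ops.aten.ones_like.default
kwargs = dict(
    input=torch.randn(1, 3, 32, 32),
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
)
script_fun = build_ts_script_function(target._schema, kwargs)
# Constants are inlined, so the only graph input is the tensor argument.
print(script_fun.graph)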
def build_mlir_module(op: OpOverload, kwargs: Dict[str, Any]) -> ir.Module:
"""Translate input function into an MLIR module in the `torch` dialect.
Parameters
----------
jit_function: ScriptFunction
Function in TorchScript IR to turn into MLIR.
annotation: Annotation
Annotation object representing the types of
the operands of `jit_function`.
op: OpOverload
Callable from the torch.ops.aten module/namespace that has a _schema field.
kwargs: Dict
A dictionary with all arguments passed in through __torch_dispatch__ (including int/float/bool params).
Returns
-------
ir.Module
Translation of the input module into an MLIR module
Translation of the input module into an MLIR module.
"""
mb = ModuleBuilder()
mb.import_function(jit_function)
func_op = get_func_op_with_name(mb.module, jit_function.name)
# The assert here is to catch tensor shapes that have size 0 dimensions, such as those produced in
# the course of evaluating SliceEndSleStartModule_basic and SliceOutOfLowerBoundEndIndexModule_basic.
# Such 0 size dimensions fail the assert at mlir/lib/IR/BuiltinTypes.cpp, line 887
annotations = []
for arg_name, arg in kwargs.items():
if isinstance(arg, torch.Tensor):
assert np.prod(arg.shape) != 0, f"{arg_name} has invalid shape {arg.shape}"
annotations.append(TorchTensorType(shape=tuple(arg.shape), dtype=arg.dtype))
annotations = tuple(annotations)
script_fun = build_ts_script_function(op._schema, kwargs)
assert len(annotations) == len(
list(script_fun.graph.inputs())
), "Number of annotations and number of graph inputs differs."
mb = ModuleBuilder()
mb.import_function(script_fun)
func_op = get_func_op_with_name(mb.module, script_fun.name)
assert (
func_op is not None
), "Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder"

View File

@ -2,21 +2,20 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from __future__ import annotations
from typing import Any, Callable, Tuple, Union
from typing import List, Dict
from typing import Any, Callable, Tuple
from typing import Dict
import numpy as np
import torch
import torch._C
from torch.fx.node import map_aggregate
from torch.fx.operator_schemas import normalize_function, create_type_hint
from torch.utils._pytree import tree_map
from torch_mlir._mlir_libs._mlir.passmanager import PassManager
from torch.fx import immutable_collections
from torch.fx.operator_schemas import (
_torchscript_schema_to_signature,
_args_kwargs_to_normalized_args_kwargs,
)
from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops
from torch_mlir.dialects import torch as torch_dialect
from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops # pytype: disable=import-error
from torch_mlir.eager_mode.ir_building import build_module, TorchTensorType
OP_REGISTRY = {op["name"]: op for op in get_registered_ops()}
SUPPORTED_OPS = frozenset(
@ -37,153 +36,44 @@ class UnsupportedByTorchMlirEagerMode(Exception):
return self.value
def check_supported_op(schema: torch._C.FunctionSchema) -> bool:
return (
"torch."
+ schema.name.replace("::", ".")
+ ("." + schema.overload_name if schema.overload_name else "")
) in SUPPORTED_OPS
def is_tensor_type(typ: torch._C.Type):
return typ.isSubtypeOf(torch.TensorType.get()) or (
isinstance(typ, torch.OptionalType)
and typ.getElementType().isSubtypeOf(torch._C.TensorType.get())
)
def normalize_args_kwargs(target: Callable, args: Tuple[Any], kwargs: Dict[str, Any]):
"""Fill in default values for optional args, which are dependent on the schema."""
arg_types = map_aggregate(args, type)
assert isinstance(arg_types, tuple)
arg_types = map_aggregate(map_aggregate(args, type), create_type_hint)
kwarg_types = {
k: create_type_hint(map_aggregate(v, type)) for k, v in kwargs.items()
}
new_args_and_kwargs = normalize_function(
target.op, args, kwargs, arg_types, kwarg_types, normalize_to_only_use_kwargs=False
sig = _torchscript_schema_to_signature(target._schema)
_, new_kwargs = _args_kwargs_to_normalized_args_kwargs(
sig, args, kwargs, normalize_to_only_use_kwargs=True
)
assert new_args_and_kwargs, "Couldn't normalize args and kwargs"
new_args, new_kwargs = new_args_and_kwargs
return new_args, new_kwargs
if "self" in new_kwargs:
new_kwargs["input"] = new_kwargs.pop("self")
# Flatten lists of args for ops that take lists, such as torch.cat.
to_remove = set()
to_add = {}
for k, v in new_kwargs.items():
if isinstance(v, (tuple, list)) and len(v) and isinstance(v[0], torch.Tensor):
to_remove.add(k)
for i, vv in enumerate(v):
to_add[f"{k}_flattened_{i}"] = vv
for rem in to_remove:
del new_kwargs[rem]
new_kwargs.update(**to_add)
# Sort here in order to have consistency across TS graph and
# MLIR module.
sorted_kwargs = dict(sorted(new_kwargs.items()))
return immutable_collections.immutable_dict(sorted_kwargs)
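As a concrete sketch of the flattening above (using torch.ops.aten.cat, as in the comment, and the imports already at the top of this file): a single tensor-list argument is replaced by one kwarg per element, so each tensor can become its own graph input.
kwargs = normalize_args_kwargs(
    torch.ops.aten.cat.default, ([torch.randn(2), torch.randn(2)],), {}
)
# Expected keys (sorted), e.g.: ['dim', 'tensors_flattened_0', 'tensors_flattened_1']
print(list(kwargs))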
def build_script_function(
schema: torch._C.FunctionSchema,
args: List[torch._C.Argument],
kwargs: Dict[str, Any],
) -> torch.jit.ScriptFunction:
"""Build a torch.jit.ScriptFunction that corresponds to the schema.
Constants are inlined for the purposes of invalidating the compile cache when they change.
"""
# Creates empty TS graph.
graph = torch._C.Graph()
# Creates and inserts node with identifier `schema.name`; NB node has no inputs or outputs at this point.
node = graph.insertNode(graph.create(schema.name, len(schema.returns)))
# Associate graph inputs/outputs with node inputs/outputs.
for i, arg in enumerate(schema.arguments):
# Find value corresponding to schema arg, either in positional or kw args.
kwarg = False
if arg.name in kwargs:
val = kwargs[arg.name]
kwarg = True
else:
val = args[i]
# If arg is a tensor, then add input to the graph corresponding to arg.
if is_tensor_type(arg.type) and val is not None:
inp = graph.addInput()
if isinstance(arg.type, torch.OptionalType):
inp.setType(arg.type.getElementType())
else:
inp.setType(arg.type)
if kwarg:
# Rename for debugging aid.
inp.setDebugName(arg.name)
# If arg is a constant, inline (at the top of the graph).
else:
if val == []:
# Some ops have empty list default values for args
# (such as aten::max_pool2d_with_indices with int[2] stride=[]),
# but graph.insertConstant doesn't recognize [] as an empty list IValue.
# This might be an upstream bug but there doesn't seem to be a way to
# build a prim::ListConstruct list that's empty.
val = None
inp = graph.insertConstant(val)
inp.node().moveBefore(node)
node.addInput(inp)
if node.hasMultipleOutputs():
for outp in node.outputs():
graph.registerOutput(outp)
else:
graph.registerOutput(node.output())
fn = torch._C._create_function_from_graph("f", graph)
return fn
def get_registered_op(op):
registered_op = OP_REGISTRY[(op._schema.name, op._schema.overload_name)]
return registered_op
def annotate_args_kwargs(
script_fun: torch._C.ScriptFunction,
normalized_args: List[Any],
normalized_kwargs: Dict[str, Any],
):
unwrapped_normalized_args = tree_map(
lambda x: x.detach().contiguous().numpy() if isinstance(x, torch.Tensor) else x,
normalized_args,
)
unwrapped_normalized_kwargs = tree_map(
lambda x: x.detach().contiguous().numpy() if isinstance(x, torch.Tensor) else x,
normalized_kwargs,
)
annotations = []
tensor_args = []
for i, arg in enumerate(unwrapped_normalized_args):
if isinstance(arg, np.ndarray):
# TODO: Remove once size zero dimensions are handled by torch-mlir.
shape = tuple(map(lambda x: x or -1, arg.shape))
annotations.append(
TorchTensorType(shape=shape, dtype=normalized_args[i].dtype)
)
tensor_args.append(arg)
# Pull out tensor kwargs and put them in positional order.
tensor_kwargs_flat = []
if unwrapped_normalized_kwargs:
tensor_kwargs = {}
arg_idxs = {
arg_name: i
for i, arg_name in enumerate(
[arg.name for arg in script_fun.schema.arguments]
)
}
for i, (kw, arg) in enumerate(unwrapped_normalized_kwargs.items()):
if isinstance(arg, np.ndarray):
tensor_kwargs[arg_idxs[kw]] = (arg, normalized_kwargs[kw].dtype)
for _i, (arg, arg_dtype) in sorted(tensor_kwargs.items()):
annotations.append(TorchTensorType(shape=tuple(arg.shape), dtype=arg_dtype))
tensor_kwargs_flat.append(arg)
return annotations, tensor_args, tensor_kwargs_flat
def write_back_to_mutable(
registered_op: Dict,
out: Union[np.ndarray, List[np.ndarray]],
all_tensor_args: List[np.ndarray],
):
def check_get_aliased_arg(func: Callable,):
"""Write back to mutable args that aren't properly handled otherwise.
Because of how we pass values to the backend (by copying the tensor to a numpy array) we don't currently support
ops that mutate operands. That includes both inplace variants and outplace variants. Additionally, Torch-MLIR,
Because of how we pass values to the backend we don't currently support ops that mutate operands.
That includes both inplace variants and outplace variants. Additionally, Torch-MLIR,
as of right now, only handles arguments with value semantics, so we need to manually fake those semantics, which
we can for these special cases. Hence, the solution is to manually write back to the same operand that the
conventional pytorch op variant would write to.
@ -191,82 +81,31 @@ def write_back_to_mutable(
Note that there are ops where multiple operands are mutable (such as batchnorm outplace variants that
mutate running_mean and running_var). We don't currently handle those.
"""
registered_op = get_registered_op(func)
if not registered_op["is_mutable"]:
return None
if len(registered_op["returns"]) > 1:
raise UnsupportedByTorchMlirEagerMode(
"TorchMLIR doesn't handle multiple aliased returns yet."
)
else:
aliased_arg = next(
arg
for arg in registered_op["arguments"]
if "alias_info" in arg and arg["alias_info"]["is_write"]
)
assert (
"alias_info" in registered_op["returns"][0]
and registered_op["returns"][0]["alias_info"]["is_write"]
and len(registered_op["returns"][0]["alias_info"]["after"]) == 1
and registered_op["returns"][0]["alias_info"]["after"][0]
)
assert (
len(aliased_arg["alias_info"]["after"]) == 1
and aliased_arg["alias_info"]["after"][0]
== registered_op["returns"][0]["alias_info"]["after"][0]
)
np.copyto(all_tensor_args[0], out)
return out
def try_torch_mlir_eager(op, args, kwargs, backend):
if hasattr(op, "op_name"):
op_name = op.op_name
elif hasattr(op, "__name__"):
# Handle builtin_function_or_method.
op_name = op.__name__
else:
raise RuntimeError(f"op {op} has no name")
if "detach" in op_name:
# We don't handle detach as it only pertains to autograd graph construction, which is handled by pytorch.
raise UnsupportedByTorchMlirEagerMode("detaching")
if not hasattr(op, "_schema"):
raise RuntimeError(f"op {op} has no schema.")
new_args, new_kwargs = normalize_args_kwargs(op.overloadpacket, args, kwargs)
if "layout" in new_kwargs and new_kwargs["layout"] not in {0, None}:
raise UnsupportedByTorchMlirEagerMode(
f"{new_kwargs['layout']} layout not supported."
)
if "memory_format" in new_kwargs and new_kwargs["memory_format"] not in {0, None}:
raise UnsupportedByTorchMlirEagerMode(
f"{new_kwargs['memory_format']} memory format not supported."
)
script_fun = build_script_function(op._schema, new_args, new_kwargs)
annotations, np_tensor_args, np_tensor_kwargs_flat = annotate_args_kwargs(
script_fun, new_args, new_kwargs
aliased_arg = next(
arg
for arg in registered_op["arguments"]
if "alias_info" in arg and arg["alias_info"]["is_write"]
)
assert (
"alias_info" in registered_op["returns"][0]
and registered_op["returns"][0]["alias_info"]["is_write"]
and len(registered_op["returns"][0]["alias_info"]["after"]) == 1
and registered_op["returns"][0]["alias_info"]["after"][0]
)
assert (
len(aliased_arg["alias_info"]["after"]) == 1
and aliased_arg["alias_info"]["after"][0]
== registered_op["returns"][0]["alias_info"]["after"][0]
)
eager_module = build_module(script_fun, annotations)
with eager_module.context:
pm = PassManager.parse(
"torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline"
)
pm.run(eager_module)
compiled_module = backend.compile(eager_module)
loaded_module = backend.load(compiled_module)
op_mlir_backend_callable = getattr(loaded_module, script_fun.name)
assert (
op_mlir_backend_callable is not None
), f"Couldn't find function {script_fun.name} in module."
all_tensor_args = np_tensor_args + np_tensor_kwargs_flat
out = op_mlir_backend_callable(*all_tensor_args)
registered_op = OP_REGISTRY[(op._schema.name, op._schema.overload_name)]
if registered_op["is_mutable"]:
out = write_back_to_mutable(registered_op, out, all_tensor_args)
return out
return aliased_arg["name"] if aliased_arg["name"] != "self" else "input"
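A hedged sketch of what this check reports for a mutable out= variant; the exact registry contents are an assumption here. The dispatcher uses the returned name to copy the backend result back into that operand.
import torch
from torch_mlir.eager_mode.torch_mlir_dispatch import check_get_aliased_arg

# aten::add.out writes through its `out` argument, so that operand name is
# expected to be returned; non-mutating overloads return None.
print(check_get_aliased_arg(torch.ops.aten.add.out))     # e.g. "out"
print(check_get_aliased_arg(torch.ops.aten.add.Tensor))  # None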

View File

@ -0,0 +1,102 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import abc
from dataclasses import dataclass
from typing import TypeVar, Tuple, Callable, List, Dict, Any
import torch
from torch_mlir._mlir_libs._mlir.ir import Module
# TODO: This might need to be an ABC too, such as
# to support finding the backend that created the tensor.
DeviceTensor = TypeVar("DeviceTensor")
@dataclass(frozen=True)
class TensorMetaData:
"""A small container for metadata necessary for satisfying the pytorch dispatcher and other code (pytorch or
otherwise) that branches on these attributes.
There is a lot of code in the PyTorch codebase that branches based on these attributes; the obvious ones here
are dtype, device, and requires_grad (necessary for autograd itself). There is ample warning from PyTorch that,
in principle, these should be as close as possible to true; see
https://github.com/albanD/subclass_zoo/blob/1566e038f03cd89ab3cc37e670a44e3c2bbc1897/trivial_tensors.py#L90-L92
The defaults (properties) simplify the api and seem to work after some testing but
might malfunction in unexpected ways.
# TODO: revisit these assumptions
"""
size: Tuple[int]
dtype: torch.dtype
requires_grad: bool
strides: Tuple[int]
storage_offset: int = 0
layout: torch.layout = torch.strided
device: torch.device = torch.device("cpu")
def __init__(
self,
size,
dtype,
requires_grad,
strides=None,
storage_offset=None,
layout=None,
device=None,
):
super().__init__()
object.__setattr__(self, "size", size)
object.__setattr__(self, "dtype", dtype)
object.__setattr__(self, "requires_grad", requires_grad)
object.__setattr__(
self, "strides", strides if strides is not None else len(size) * [0]
)
object.__setattr__(
self, "storage_offset", storage_offset if storage_offset is not None else 0
)
object.__setattr__(
self, "layout", layout if layout is not None else torch.strided
)
object.__setattr__(
self, "device", device if device is not None else torch.device("cpu")
)
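For example (a sketch based on the constructor above), only size, dtype, and requires_grad are required; the remaining fields fall back to contiguous CPU defaults.
meta = TensorMetaData(size=(2, 3), dtype=torch.float32, requires_grad=False)
assert meta.device == torch.device("cpu") and meta.layout == torch.strided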
class TorchMLIREagerBackend(abc.ABC):
@abc.abstractmethod
def compile(
self, module: Module
) -> Callable[[List[DeviceTensor]], List[DeviceTensor]]:
raise NotImplementedError
@abc.abstractmethod
def transfer_from_torch_to_device(self, tensor: torch.Tensor) -> DeviceTensor:
"""Unwrap the backend representation in order to build a torch.Tensor."""
raise NotImplementedError
@abc.abstractmethod
def get_torch_metadata(
self, tensor: DeviceTensor, kwargs: Dict[str, Any]
) -> TensorMetaData:
"""Parse relevant tensor metadata from backend device array (e.g., shape, stride, layout) in order to build
wrapper tensor."""
raise NotImplementedError
@abc.abstractmethod
def transfer_from_device_to_torch(self, tensor: DeviceTensor) -> torch.Tensor:
"""If compilation fails for some reason then device specific representations need to be munged into a
torch.Tensor representation.
"""
raise NotImplementedError
@abc.abstractmethod
def copy_into(self, dst: DeviceTensor, src: DeviceTensor):
"""This method is needed for things like handling aliased arguments."""
raise NotImplementedError

View File

@ -2,26 +2,65 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import contextlib
import re
import traceback
import warnings
from typing import Any
import torch
from torch.utils._pytree import tree_map
from torch_mlir.eager_mode.ir_building import build_mlir_module
from torch_mlir.eager_mode.torch_mlir_dispatch import (
try_torch_mlir_eager,
UnsupportedByTorchMlirEagerMode,
normalize_args_kwargs,
check_get_aliased_arg,
)
from torch_mlir.eager_mode import EAGER_MODE_DEBUG
from torch_mlir_e2e_test.eager_backends.refbackend import EagerModeRefBackend
@contextlib.contextmanager
def no_dispatch():
"""Prevent infinite recursion in case accidentally calling a tensor method on a TorchMLIRTensor within
__torch_dispatch__."""
guard = torch._C._DisableTorchDispatch()
try:
yield
finally:
del guard
backend = EagerModeRefBackend()
UNSUPPORTED_OPS = re.compile(
"|".join([
# We don't handle detach as it only pertains to autograd graph construction, which is handled by pytorch.
"detach",
# We don't handle _local_scalar_dense because it's just a way to unwrap a tensor that wraps a number.
"_local_scalar_dense",
# https://github.com/llvm/torch-mlir/issues/878
"_unsafe_view",
"view",
])
)
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
class TorchMLIRTensor(torch.Tensor):
"""Wrap torch.Tensor in order to dispatch through torch-mlir instead of aten.
"""This class serves the role abstract class with common functionality for dispatching through Torch-MLIR instead of aten.
This class uses the _make_wrapper_subclass pattern to override __torch_dispatch__
in order to dispatch through torch-mlir instead of aten. Here we basically only unwrap and wrap
torch.Tensors. Most of the heavy lifting is done in the adjacent torch_mlir_dispatch module.
It defers device specific behavior to device specific implementations. The deriving classes use the
make_bare_wrapper_subclass convenience method, adjacent here, and override __torch_dispatch__ in order to dispatch
through Torch-MLIR instead of aten. Backends are free to choose whatever representation of the buffers (i.e., `elem`) they like
and are expected to provide conversion mechanisms between their representation and torch.Tensor.
More documentation on how this pattern works can be found in this forum post
Here we only verify that inputs abide by the currently supported features of Torch-MLIR (contiguous memory and
strided tensor layout) and build the mlir module. Importantly, we also recover from any malfunctions in the
deriving classes and dispatch back to conventional PyTorch.
More documentation on how the __torch_dispatch__ pattern works can be found in this forum post
https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557
and this RFC
https://github.com/pytorch/rfcs/blob/master/RFC-0001-torch-function-for-methods.md#process-followed-during-a-functionmethod-call
@ -29,83 +68,188 @@ class TorchMLIRTensor(torch.Tensor):
https://github.com/albanD/subclass_zoo
"""
elem: torch.Tensor
elem: Any
__slots__ = ["elem"]
@staticmethod
def __new__(cls, elem, *args, **kwargs):
r = torch.Tensor._make_wrapper_subclass(
cls,
elem.size(),
strides=elem.stride(),
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=elem.device,
# Only float tensors can have gradients.
requires_grad=elem.dtype in {torch.float, torch.float32, torch.float64}
and (kwargs.get("requires_grad", False) or elem.requires_grad),
)
r.elem = elem.detach() if r.requires_grad else elem
def __new__(cls, elem, **kwargs):
"""Wrap elem (which could be a torch.Tensor or otherwise) in a torch.Tensor subclass.
Critically, this method needs to parse relevant metadata from the device representation
(such as shape, striding, dtype, etc.) and translate it into torch conventions.
Deriving classes must provide a way to construct themselves from either their device specific representation
or a torch.Tensor; the latter handles the case where we dispatch back to PyTorch to recover from an error.
"""
if kwargs.get("constructing_from_device_tensor", False):
tensor_meta_data = backend.get_torch_metadata(elem, kwargs)
r = make_bare_wrapper_subclass(
cls=cls,
size=tensor_meta_data.size,
strides=tensor_meta_data.strides,
storage_offset=tensor_meta_data.storage_offset,
dtype=tensor_meta_data.dtype,
layout=tensor_meta_data.layout,
device=tensor_meta_data.device,
requires_grad=tensor_meta_data.requires_grad,
)
r.elem = elem
elif isinstance(elem, torch.nn.Parameter):
r = make_wrapper_subclass_from_torch_tensor(cls, elem.data, **kwargs)
r.elem = backend.transfer_from_torch_to_device(elem.detach().data)
elif isinstance(elem, torch.Tensor):
r = make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs)
r.elem = backend.transfer_from_torch_to_device(elem)
# This branch handles the case when a python scalar is passed to some op
# or is returned from some aten op, such as _local_scalar_dense.
elif isinstance(elem, (int, float, bool)):
return elem
else:
raise ValueError(f"Unknown element type: {type(elem)}")
return r
def __repr__(self):
if self.grad_fn:
return f"TorchMLIRTensor({self.elem}, grad_fn={self.grad_fn})"
return f"TorchMLIRTensor({self.elem}, backend={backend.__class__.__name__}, grad_fn={self.grad_fn})"
else:
return f"TorchMLIRTensor({self.elem})"
return f"TorchMLIRTensor({self.elem}, backend={backend.__class__.__name__})"
@classmethod
def __torch_dispatch__(cls, func, _types, args=(), kwargs=None):
requires_grad = False
def check_grad(e):
nonlocal requires_grad
if isinstance(e, TorchMLIRTensor):
requires_grad |= e.requires_grad
tree_map(check_grad, args)
tree_map(check_grad, kwargs)
def unwrap(e):
if isinstance(e, TorchMLIRTensor):
return e.elem
if isinstance(e, torch.nn.Parameter):
return e.detach()
return e
def wrap(e):
nonlocal requires_grad
return (
TorchMLIRTensor(e, requires_grad=requires_grad)
if isinstance(e, torch.Tensor)
else e
)
unwrapped_args = tree_map(unwrap, args)
unwrapped_kwargs = tree_map(unwrap, kwargs)
requires_grad = check_requires_grad(*args, **kwargs)
try:
out = try_torch_mlir_eager(
func,
unwrapped_args,
unwrapped_kwargs,
backend=refbackend.RefBackendLinalgOnTensorsBackend(),
with no_dispatch():
if hasattr(func, "op_name"):
op_name = func.op_name
elif hasattr(func, "__name__"):
# Handle builtin_function_or_method.
op_name = func.__name__
else:
raise RuntimeError(f"op {func} has no name")
if UNSUPPORTED_OPS.match(op_name):
raise UnsupportedByTorchMlirEagerMode(op_name)
if not hasattr(func, "_schema"):
raise RuntimeError(f"op {func} has no schema.")
normalized_kwargs = normalize_args_kwargs(func, args, kwargs)
if "layout" in normalized_kwargs and normalized_kwargs[
"layout"
] not in {0, None}:
raise UnsupportedByTorchMlirEagerMode(
f"{normalized_kwargs['layout']} layout not supported."
)
if "memory_format" in normalized_kwargs and normalized_kwargs[
"memory_format"
] not in {0, None}:
raise UnsupportedByTorchMlirEagerMode(
f"{normalized_kwargs['memory_format']} memory format not supported."
)
eager_module = build_mlir_module(func, normalized_kwargs)
device_tensor_args = [
kwarg.elem
for _, kwarg in normalized_kwargs.items()
if isinstance(kwarg, cls)
]
assert len(eager_module.body.operations[0].arguments) == len(
device_tensor_args
), "Number of parameters and number of arguments differs."
op_mlir_backend_callable = backend.compile(eager_module)
out = op_mlir_backend_callable(*device_tensor_args)
out = tree_map(
lambda x: cls(
x, requires_grad=requires_grad, constructing_from_device_tensor=True
),
out,
)
if isinstance(out, tuple):
out = [torch.from_numpy(o) for o in out]
else:
out = torch.from_numpy(out)
return tree_map(wrap, out)
except Exception as e:
if isinstance(e, UnsupportedByTorchMlirEagerMode):
warnings.warn(
f"Couldn't use TorchMLIR eager because current incompatibility: *{str(e)}*; running through PyTorch eager."
)
else:
warnings.warn(
f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; "
f"running through PyTorch eager. Please file an issue at https://github.com/llvm/torch-mlir/issues"
)
return tree_map(wrap, func(*unwrapped_args, **unwrapped_kwargs))
if EAGER_MODE_DEBUG:
warnings.warn(traceback.format_exc())
if isinstance(e, UnsupportedByTorchMlirEagerMode):
warnings.warn(
f"Couldn't use TorchMLIR eager because current incompatibility: *{str(e)}*; running through PyTorch eager."
)
else:
warnings.warn(
f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; "
f"running through PyTorch eager. Please file an issue at https://github.com/llvm/torch-mlir/issues"
)
with no_dispatch():
unwrapped_args = tree_map(cls.unwrap, args)
unwrapped_kwargs = tree_map(cls.unwrap, kwargs)
out = func(*unwrapped_args, **unwrapped_kwargs)
out = tree_map(lambda x: cls(x, requires_grad=requires_grad), out)
maybe_aliased_arg_name = check_get_aliased_arg(func)
if maybe_aliased_arg_name is not None:
backend.copy_into(normalized_kwargs[maybe_aliased_arg_name].elem, out.elem)
return out
@classmethod
def unwrap(cls, e):
"""Unwrap the TorchMLIRTensor representation in order to access the actual device specific representation."""
if isinstance(e, cls):
return backend.transfer_from_device_to_torch(e.elem)
return e
def check_requires_grad(*args, **kwargs):
requires_grad = False
def check_grad(e):
nonlocal requires_grad
if isinstance(e, TorchMLIRTensor):
requires_grad |= e.requires_grad
tree_map(check_grad, args)
tree_map(check_grad, kwargs)
return requires_grad
def make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs):
"""Convenience method that parse out relevant metadata from a torch.Tensor, in order to produce
a wrapper subclass.
NB: this convenience method does not set that `elem` attribute of the subclass, as that is the responsibility
of the device specific implementation.
"""
r = make_bare_wrapper_subclass(
cls=cls,
size=elem.size(),
strides=elem.stride(),
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=elem.device,
# Only float tensors can have gradients.
requires_grad=elem.dtype in {torch.float, torch.float32, torch.float64}
and (kwargs.get("requires_grad", False) or elem.requires_grad),
)
return r
def make_bare_wrapper_subclass(
*, cls, size, strides, storage_offset, dtype, layout, device, requires_grad
):
"""Convenience method that builds a wrapper subclass.
NB: this convenience method does not set the `elem` attribute of the subclass, as that is the responsibility
of the device specific implementation.
"""
return torch.Tensor._make_wrapper_subclass(
cls,
size,
strides=strides,
storage_offset=storage_offset,
dtype=dtype,
layout=layout,
device=device,
requires_grad=requires_grad,
)
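Putting the pieces together, a hedged usage sketch (the import path torch_mlir.eager_mode.torch_mlir_tensor is assumed from this file's location): wrapping plain tensors routes the aten call through __torch_dispatch__, and unwrap transfers the backend buffer back to a torch.Tensor.
import torch
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor

a = TorchMLIRTensor(torch.randn(4, 4))
b = TorchMLIRTensor(torch.randn(4, 4))
c = a + b  # compiled and run through Torch-MLIR, with fallback to PyTorch on error
print(TorchMLIRTensor.unwrap(c))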

View File

@ -0,0 +1,91 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from __future__ import annotations
from typing import Dict, Any
import numpy as np
import torch
from torch_mlir.compiler_utils import (
get_module_name_for_debug_dump,
run_pipeline_with_repro_report,
)
from torch_mlir.eager_mode.torch_mlir_eager_backend import (
TorchMLIREagerBackend,
TensorMetaData,
)
from torch_mlir.ir import Module
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
RefBackendLinalgOnTensorsBackend,
)
NUMPY_TO_TORCH_DTYPE_DICT = {
np.bool: torch.bool,
np.bool_: torch.bool,
np.uint8: torch.uint8,
np.int8: torch.int8,
np.int16: torch.int16,
np.int32: torch.int32,
np.int64: torch.int64,
np.float16: torch.float16,
np.float32: torch.float32,
np.float64: torch.float64,
np.complex64: torch.complex64,
np.complex128: torch.complex128,
}
_ref_backend = RefBackendLinalgOnTensorsBackend()
class EagerModeRefBackend(TorchMLIREagerBackend):
"""Main entry-point for the reference backend for eager mode.
RefBackend uses numpy.ndarray representations of tensors and thus all of the wrapping and unwrapping
and munging here is done to convert between torch.Tensor and numpy.ndarray.
"""
module_to_refbackend_invoker = {}
def get_torch_metadata(
self, tensor: np.ndarray, kwargs: Dict[str, Any]
) -> TensorMetaData:
return TensorMetaData(
size=tensor.shape,
dtype=NUMPY_TO_TORCH_DTYPE_DICT[tensor.dtype.type],
requires_grad=tensor.dtype in {np.float, np.float32, np.float64}
and kwargs.get("requires_grad", False),
)
def compile(self, imported_module: Module):
"""Lower the imported TS module to linalg and then further compile for the reference backend and then call."""
fn_name = get_module_name_for_debug_dump(imported_module)
module_hash = str(imported_module)
if module_hash not in self.module_to_refbackend_invoker:
run_pipeline_with_repro_report(
imported_module,
"torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline",
"EagerMode",
)
self.module_to_refbackend_invoker[module_hash] = _ref_backend.load(
_ref_backend.compile(imported_module)
)
ref_backend_invoker = self.module_to_refbackend_invoker[module_hash]
op_mlir_backend_callable = getattr(ref_backend_invoker, fn_name)
assert (
op_mlir_backend_callable is not None
), f"Couldn't find function in module."
return op_mlir_backend_callable
def copy_into(self, dst: np.ndarray, src: np.ndarray):
np.copyto(dst, src)
def transfer_from_device_to_torch(self, e: np.ndarray):
return torch.from_numpy(e).clone()
def transfer_from_torch_to_device(self, tensor: torch.Tensor) -> np.ndarray:
return tensor.numpy()
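A short round-trip sketch for the buffer representation used by this backend (numpy arrays; imports as at the top of this file):
backend = EagerModeRefBackend()
device_buf = backend.transfer_from_torch_to_device(torch.ones(2, 2))  # numpy.ndarray
assert isinstance(device_buf, np.ndarray)
restored = backend.transfer_from_device_to_torch(device_buf)  # independent torch.Tensor copy
assert torch.allclose(restored, torch.ones(2, 2))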

View File

@ -15,7 +15,27 @@ def wrap(e):
def unwrap(e):
return e.elem.clone() if isinstance(e, TorchMLIRTensor) else e
return TorchMLIRTensor.unwrap(e) if isinstance(e, TorchMLIRTensor) else e
def to_tmt(m: torch.nn.Module):
for buf_name, buf in m.named_buffers(recurse=True):
if isinstance(buf, TorchMLIRTensor):
continue
m.register_buffer(buf_name, TorchMLIRTensor(buf))
for param_name, param in m.named_parameters(recurse=True):
if isinstance(param, TorchMLIRTensor):
continue
m.register_parameter(
param_name,
torch.nn.Parameter(
TorchMLIRTensor(param), requires_grad=param.requires_grad
),
)
for attr in dir(m):
field = getattr(m, attr)
if isinstance(field, torch.Tensor) and not isinstance(field, TorchMLIRTensor):
setattr(m, attr, TorchMLIRTensor(field))
class EagerModeTestConfig(TestConfig):
@ -25,13 +45,14 @@ class EagerModeTestConfig(TestConfig):
super().__init__()
def compile(self, program: torch.nn.Module) -> torch.nn.Module:
program.apply(to_tmt)
return program
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
result: Trace = []
for item in trace:
attr = artifact
for part in item.symbol.split('.'):
for part in item.symbol.split("."):
attr = getattr(attr, part)
inps = tree_map(wrap, item.inputs)
@ -39,7 +60,6 @@ class EagerModeTestConfig(TestConfig):
output = tree_map(unwrap, outps)
result.append(
TraceItem(symbol=item.symbol,
inputs=item.inputs,
output=output))
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
)
return result