mirror of https://github.com/llvm/torch-mlir
413 lines
13 KiB
Python
413 lines
13 KiB
Python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
"""Base classes and interfaces for mapping ops to MLIR.
|
|
|
|
The goal of this facility is to define the majority of op mappings by example,
|
|
from the actual Python invocation, using the PyTorch tracer to extract its
|
|
Graph IR, and introspecting that to determine mappings, metadata and types
|
|
for corresponding MLIR definitions. This is not meant to cover everything,
|
|
and it is expected that a number of hard ops should be mapped by hand.
|
|
|
|
The result of building such an OpRegistry should be a data structure that
|
|
can be used to generate ODS and tables for doing systematic imports from
|
|
a PyTorch Graph to corresponding MLIR module.
|
|
|
|
Example usage (fully automatic discovery):
|
|
|
|
>>> r = OpRegistry()
|
|
>>> r.op(torch.add,
|
|
TensorValue("input"),
|
|
TensorValue("other"),
|
|
alpha=ScalarValue()).with_outref_variant()
|
|
"""
|
|
|
|
from typing import Dict, List, Optional, Sequence, Tuple
|
|
|
|
import logging
|
|
import random
|
|
import torch
|
|
|
|
__all__ = [
|
|
"SimpleOpMapping",
|
|
"OpRegistry",
|
|
"ScalarValue",
|
|
"TensorOutRef",
|
|
"TensorValue",
|
|
]
|
|
|
|
|
|
def _is_same_value(value1, value2):
|
|
# Tensors are considered via reference equality.
|
|
value1_tensor = isinstance(value1, torch.Tensor)
|
|
value2_tensor = isinstance(value2, torch.Tensor)
|
|
if value1_tensor or value2_tensor:
|
|
return value1 is value2
|
|
|
|
# Everything else is value equality.
|
|
return value1 == value2
|
|
|
|
|
|
def _extract_immediate(node):
|
|
"""Extracts an immediate value from a node.
|
|
|
|
Supported node types:
|
|
prim::Constant
|
|
prim::ListConstruct
|
|
"""
|
|
# Constant
|
|
if node.kind() == "prim::Constant":
|
|
# Try to extract as different types.
|
|
try:
|
|
return node.t("value")
|
|
except RuntimeError:
|
|
pass
|
|
try:
|
|
return node.f("value")
|
|
except RuntimeError:
|
|
pass
|
|
try:
|
|
return node.i("value")
|
|
except RuntimeError:
|
|
pass
|
|
try:
|
|
return node.s("value")
|
|
except RuntimeError:
|
|
pass
|
|
return None
|
|
# List
|
|
elif node.kind() == "prim::ListConstruct":
|
|
return [_extract_immediate(i.node()) for i in node.inputs()]
|
|
else:
|
|
raise ValueError("Unrecognized immediate input node: {!r}".format(node))
|
|
|
|
|
|
class ValueSpec:
|
|
"""Base class for inputs to operations.
|
|
|
|
This binds information about how the input is mapped to the MLIR operation.
|
|
"""
|
|
|
|
def __init__(self, name=None):
|
|
super().__init__()
|
|
self.name = name
|
|
|
|
@property
|
|
def mlir_ods_predicate(self):
|
|
return "AnyType"
|
|
|
|
def generate_example(self, index=0):
|
|
"""Generates an example value."""
|
|
raise NotImplementedError()
|
|
|
|
def __repr__(self):
|
|
return "{}({!r})".format(self.__class__.__name__, self.name)
|
|
|
|
|
|
class TensorValue(ValueSpec):
|
|
"""An input that is a tensor."""
|
|
|
|
def __init__(self, name=None, *, example_size=None):
|
|
super().__init__(name=name)
|
|
if example_size is None:
|
|
example_size = (2, 3, 7) # No significance.
|
|
self.example_size = example_size
|
|
|
|
@property
|
|
def mlir_ods_predicate(self):
|
|
return "ATen_AnyTensor"
|
|
|
|
def generate_example(self, index=0):
|
|
return torch.rand(*self.example_size)
|
|
|
|
|
|
class TensorOutRef(ValueSpec):
|
|
"""A tensor that is passed by ref as an out parameter."""
|
|
|
|
def __init__(self, name=None, *, example_size=None):
|
|
super().__init__(name=name)
|
|
if example_size is None:
|
|
example_size = (2, 3, 7) # No significance.
|
|
self.example_size = example_size
|
|
|
|
@property
|
|
def mlir_ods_predicate(self):
|
|
return "ATen_AnyRefTensor"
|
|
|
|
def generate_example(self, index=0):
|
|
return torch.rand(*self.example_size)
|
|
|
|
|
|
class ScalarValue(ValueSpec):
|
|
"""An input that is a scalar."""
|
|
|
|
def __init__(self, name=None, value=None):
|
|
super().__init__(name=name)
|
|
self.value = value
|
|
|
|
@property
|
|
def mlir_ods_predicate(self):
|
|
return "ATen_AnyScalar"
|
|
|
|
def generate_example(self, index=0):
|
|
if self.value is not None:
|
|
return self.value
|
|
return 1.0 + index # Generates a stable value.
|
|
|
|
|
|
class OpMapping:
|
|
"""Base class for things purporting to map an operation."""
|
|
pass
|
|
|
|
|
|
class SimpleOpMapping(OpMapping):
|
|
"""Maps a PyTorch invocation to its MLIR representation."""
|
|
|
|
def __init__(self, op_f, *op_args, **op_kwargs):
|
|
super().__init__()
|
|
self.op_f = op_f
|
|
self.op_args = op_args
|
|
self.op_kwargs = op_kwargs
|
|
self.outref_variant_value = None # type: Optional[TensorOutRef]
|
|
|
|
# Set after finalize.
|
|
self.op_kind = None # type: Optional[str]
|
|
self.op_arity = -1 # type: int
|
|
self.operand_map = None # type: Optional[List[Tuple[int, ValueSpec]]]
|
|
self.result_map = None # type: Optional[List[Tuple[int, ValueSpec]]]
|
|
self.mlir_operation_name = None # type: Optional[str]
|
|
|
|
def __repr__(self):
|
|
return ("SimpleOp({kind!r}[{arity}] -> {name!s}, operands={operands!r}, "
|
|
"results={results!r})".format(kind=self.op_kind,
|
|
arity=self.op_arity,
|
|
name=self.mlir_operation_name,
|
|
operands=self.operand_map,
|
|
results=self.result_map))
|
|
|
|
def clone(self) -> "SimpleOpMapping":
|
|
copy = SimpleOpMapping(self.op_f, *self.op_args, **self.op_kwargs)
|
|
for name in [
|
|
"outref_variant_value", "op_kind", "op_arity", "operand_map",
|
|
"result_map", "mlir_operation_name"
|
|
]:
|
|
setattr(copy, name, getattr(self, name))
|
|
return copy
|
|
|
|
def with_outref_variant(self, value=None):
|
|
"""Instructs the registry to also generate an outref variant.
|
|
|
|
This is done by cloning the op prior to finalizing and adding an out=
|
|
paramer.
|
|
"""
|
|
self.outref_variant_value = TensorOutRef() if value is None else value
|
|
return self
|
|
|
|
@property
|
|
def all_arg_values(self) -> List[ValueSpec]:
|
|
"""Returns all arg values (either positional or kw)."""
|
|
return list(self.op_args) + list(self.op_kwargs.values())
|
|
|
|
@property
|
|
def is_outref_form(self) -> bool:
|
|
"""Whether the op contains an out parameter that aliases to the result."""
|
|
return any(isinstance(a, TensorOutRef) for a in self.all_arg_values)
|
|
|
|
def generate_example(self) -> Tuple[Tuple, Dict]:
|
|
"""Generates an example signature for invoking the op.
|
|
|
|
Returns:
|
|
(tuple, dict) of positional and keyword args.
|
|
"""
|
|
index = 0
|
|
positional = list()
|
|
kw = dict()
|
|
for op_arg in self.op_args:
|
|
positional.append(op_arg.generate_example(index))
|
|
index += 1
|
|
for kw_name, kw_value in self.op_kwargs.items():
|
|
kw[kw_name] = kw_value.generate_example(index)
|
|
index += 1
|
|
return positional, kw
|
|
|
|
def finalize(self):
|
|
"""Finalizes the mapping once all hints have been applied."""
|
|
# Update the name on all args if undefined.
|
|
for index, op_arg in enumerate(self.op_args):
|
|
if op_arg.name is None:
|
|
op_arg.name = "arg%d".format(index)
|
|
for key, op_arg in self.op_kwargs.items():
|
|
if op_arg.name is None:
|
|
op_arg.name = key
|
|
|
|
# Create an example graph and configure from it.
|
|
self._configure_from_example()
|
|
|
|
# Determine default operation name.
|
|
if self.mlir_operation_name is None:
|
|
self._set_default_mlir_operation_name()
|
|
|
|
def _set_default_mlir_operation_name(self):
|
|
op_ns, op_name = self.op_kind.split("::", maxsplit=1)
|
|
# Since these are emitted into the "aten" dialect namespace, alias them
|
|
# to omit the prefix to distinguish from custom ops and others (which will
|
|
# have a prefix).
|
|
default_name = op_name if op_ns == "aten" else op_ns + "." + op_name
|
|
if op_ns == "aten":
|
|
default_name = op_name
|
|
else:
|
|
default_name = op_ns + "." + op_name
|
|
|
|
if self.is_outref_form:
|
|
default_name += ".inplace"
|
|
self.mlir_operation_name = default_name
|
|
|
|
def _configure_from_example(self):
|
|
# Trace the op so that we get an example graph like this:
|
|
# %0 : Float(2, 3, 7) = prim::Constant[value=<Tensor>]()
|
|
# %1 : Float(2, 3, 7) = prim::Constant[value=<Tensor>]()
|
|
# %2 : float = prim::Constant[value=3.]()
|
|
# %3 : Float(2, 3, 7) = aten::add(%0, %1, %2)
|
|
# return (%3)
|
|
# The def of the return value is expected to be the modeled op. The
|
|
# inputs to that op are expected to be captured constants that can be
|
|
# re-associated to the example inputs.
|
|
example_args, example_kwargs = self.generate_example()
|
|
|
|
def forward():
|
|
return self.op_f(*example_args, **example_kwargs)
|
|
|
|
trace = torch.jit.trace(forward, tuple())
|
|
graph = trace.graph
|
|
logging.debug("Graph for op %r: %s", self.op_f, graph)
|
|
|
|
# Track up from the return node and assume this is our op.
|
|
return_node = graph.return_node()
|
|
return_inputs = list(return_node.inputs())
|
|
assert len(return_inputs) == 1, "Expected one input return"
|
|
op_node = return_inputs[0].node()
|
|
op_inputs = list(op_node.inputs())
|
|
logging.debug("Found op node: %r", op_node)
|
|
|
|
# Meta-data about the source op.
|
|
self.op_kind = op_node.kind()
|
|
self.op_arity = len(op_inputs)
|
|
if self.operand_map is None:
|
|
self.operand_map = self._associate_inputs(op_inputs, example_args,
|
|
example_kwargs)
|
|
|
|
# Results.
|
|
op_outputs = list(op_node.outputs())
|
|
if self.result_map is None:
|
|
if self.is_outref_form:
|
|
# Only support single outref results.
|
|
assert len(op_outputs) == 1, (
|
|
"For outref ops, only a single output is supported")
|
|
self.result_map = [(0, TensorOutRef("result"))]
|
|
else:
|
|
# Map results in order.
|
|
self.result_map = []
|
|
|
|
def result_name(i):
|
|
if len(op_outputs) == 1:
|
|
return "result"
|
|
else:
|
|
return "result%d" % i
|
|
|
|
for i, op_output in enumerate(op_outputs):
|
|
op_output_type = op_output.type()
|
|
if issubclass(type(op_output_type), torch.TensorType):
|
|
self.result_map.append((i, TensorValue(result_name(i))))
|
|
else:
|
|
raise ValueError(
|
|
"Unsupported op output type: {!r}".format(op_output_type))
|
|
return self
|
|
|
|
def _associate_inputs(self, op_inputs, example_args, example_kwargs):
|
|
"""Given inputs to a graph op node, associates to the input args.
|
|
|
|
This will match up example arguments with what was produced in the graph,
|
|
setting the operand_map.
|
|
|
|
Returns:
|
|
List of (input_index, ValueSpec) mapping inputs to the graph node to
|
|
provided values in the op definition.
|
|
"""
|
|
assert len(example_args) == len(self.op_args)
|
|
assert example_kwargs.keys() == self.op_kwargs.keys()
|
|
|
|
def find_arg(value):
|
|
for i, arg in enumerate(example_args):
|
|
if _is_same_value(arg, value):
|
|
return self.op_args[i]
|
|
for key, arg in example_kwargs.items():
|
|
if _is_same_value(arg, value):
|
|
return self.op_kwargs[key]
|
|
|
|
raise KeyError("Op input not in arguments: {!r} -> {!r}".format(
|
|
value, op_inputs))
|
|
|
|
operand_map = []
|
|
for i, op_input in enumerate(op_inputs):
|
|
input_node = op_input.node()
|
|
immediate_value = _extract_immediate(input_node)
|
|
if immediate_value is not None:
|
|
operand_map.append((i, find_arg(immediate_value)))
|
|
return operand_map
|
|
|
|
|
|
class OpRegistry:
|
|
"""Maintains a registry of op mappings."""
|
|
|
|
def __init__(self):
|
|
super().__init__()
|
|
self._mappings = []
|
|
self._pending_mapping = None
|
|
|
|
def op(self, op_f, *op_args, **op_kwargs):
|
|
"""Forwards to the SimpleOpMapping constructor and adds it.
|
|
|
|
The mapping is not finalized until either the registry is finalized or the
|
|
next op mapping is added. This allows tweaks to the mapping to be done
|
|
inline prior to performing detailed introspection.
|
|
|
|
Returns:
|
|
The SimpleOpMapping instance.
|
|
"""
|
|
self._finalize_pending()
|
|
m = SimpleOpMapping(op_f, *op_args, **op_kwargs)
|
|
self._pending_mapping = m
|
|
return m
|
|
|
|
@property
|
|
def mappings(self) -> Sequence[OpMapping]:
|
|
"""Returns the list of OpMapping.
|
|
|
|
Returns:
|
|
Sequence of OpMapping concrete classes (most commonly SimpleOpMapping).
|
|
"""
|
|
self._finalize_pending()
|
|
return self._mappings
|
|
|
|
def _finalize_pending(self):
|
|
if not self._pending_mapping:
|
|
return
|
|
|
|
outref_mapping = None
|
|
pending_mapping = self._pending_mapping
|
|
self._pending_mapping = None
|
|
if pending_mapping.outref_variant_value:
|
|
# Generate a variant (with an out= form).
|
|
outref_mapping = pending_mapping.clone()
|
|
outref_mapping.op_kwargs["out"] = outref_mapping.outref_variant_value
|
|
outref_mapping.outref_variant_value = None
|
|
|
|
# Finalize the original.
|
|
pending_mapping.finalize()
|
|
self._mappings.append(pending_mapping)
|
|
|
|
# Finalize the outref form if generated.
|
|
if outref_mapping:
|
|
outref_mapping.finalize()
|
|
self._mappings.append(outref_mapping)
|