# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Base classes and interfaces for mapping ops to MLIR.

The goal of this facility is to define the majority of op mappings by example,
from the actual Python invocation, using the PyTorch tracer to extract its
Graph IR, and introspecting that to determine mappings, metadata and types
for corresponding MLIR definitions. This is not meant to cover everything,
and it is expected that a number of hard ops should be mapped by hand.

The result of building such an OpRegistry should be a data structure that
can be used to generate ODS and tables for doing systematic imports from
a PyTorch Graph to corresponding MLIR module.

Example usage (fully automatic discovery):

  >>> r = OpRegistry()
  >>> r.op(torch.add,
  ...      TensorValue("input"),
  ...      TensorValue("other"),
  ...      alpha=ScalarValue()).with_outref_variant()
"""

from typing import Dict, List, Optional, Sequence, Tuple
import logging
import random

import torch

__all__ = [
    "SimpleOpMapping",
    "OpRegistry",
    "ScalarValue",
    "TensorOutRef",
    "TensorValue",
]


def _is_same_value(value1, value2):
    """Returns whether two example values denote the same argument.

    Tensors are compared by reference (identity); everything else is
    compared by value equality.
    """
    if isinstance(value1, torch.Tensor) or isinstance(value2, torch.Tensor):
        return value1 is value2
    return value1 == value2


def _extract_immediate(node):
    """Extracts an immediate value from a node.

    Supported node types:
      prim::Constant
      prim::ListConstruct

    Returns:
      The constant's "value" attribute (or None if none of the typed
      accessors can read it) for prim::Constant; a list of recursively
      extracted immediates for prim::ListConstruct.

    Raises:
      ValueError: If the node is of any other kind.
    """
    kind = node.kind()
    if kind == "prim::Constant":
        # The attribute type is not known up front, so probe the typed
        # accessors in order: tensor, float, int, string.
        for accessor_name in ("t", "f", "i", "s"):
            try:
                return getattr(node, accessor_name)("value")
            except RuntimeError:
                continue
        return None
    if kind == "prim::ListConstruct":
        return [_extract_immediate(i.node()) for i in node.inputs()]
    raise ValueError("Unrecognized immediate input node: {!r}".format(node))


class ValueSpec:
    """Base class for inputs to operations.

    This binds information about how the input is mapped to the MLIR operation.
    """

    def __init__(self, name=None):
        super().__init__()
        self.name = name

    @property
    def mlir_ods_predicate(self):
        """Name of the ODS type predicate used for this value."""
        return "AnyType"

    def generate_example(self, index=0):
        """Generates an example value."""
        raise NotImplementedError()

    def __repr__(self):
        return "{}({!r})".format(self.__class__.__name__, self.name)


class TensorValue(ValueSpec):
    """An input that is a tensor."""

    def __init__(self, name=None, *, example_size=None):
        super().__init__(name=name)
        if example_size is None:
            example_size = (2, 3, 7)  # No significance.
        self.example_size = example_size

    @property
    def mlir_ods_predicate(self):
        return "ATen_AnyTensor"

    def generate_example(self, index=0):
        """Returns a fresh `torch.rand` tensor of `example_size`."""
        return torch.rand(*self.example_size)


class TensorOutRef(ValueSpec):
    """A tensor that is passed by ref as an out parameter."""

    def __init__(self, name=None, *, example_size=None):
        super().__init__(name=name)
        if example_size is None:
            example_size = (2, 3, 7)  # No significance.
        self.example_size = example_size

    @property
    def mlir_ods_predicate(self):
        return "ATen_AnyRefTensor"

    def generate_example(self, index=0):
        """Returns a fresh `torch.rand` tensor of `example_size`."""
        return torch.rand(*self.example_size)


class ScalarValue(ValueSpec):
    """An input that is a scalar."""

    def __init__(self, name=None, value=None):
        super().__init__(name=name)
        self.value = value

    @property
    def mlir_ods_predicate(self):
        return "ATen_AnyScalar"

    def generate_example(self, index=0):
        """Returns the configured value, or `1.0 + index` if none was given."""
        if self.value is not None:
            return self.value
        return 1.0 + index  # Generates a stable value.


class OpMapping:
    """Base class for things purporting to map an operation."""
    pass


class SimpleOpMapping(OpMapping):
    """Maps a PyTorch invocation to its MLIR representation."""

    def __init__(self, op_f, *op_args, **op_kwargs):
        super().__init__()
        self.op_f = op_f
        self.op_args = op_args
        self.op_kwargs = op_kwargs
        self.outref_variant_value: Optional[TensorOutRef] = None

        # Set after finalize.
        self.op_kind: Optional[str] = None
        self.op_arity: int = -1
        self.operand_map: Optional[List[Tuple[int, ValueSpec]]] = None
        self.result_map: Optional[List[Tuple[int, ValueSpec]]] = None
        self.mlir_operation_name: Optional[str] = None

    def __repr__(self):
        return ("SimpleOp({kind!r}[{arity}] -> {name!s}, operands={operands!r}, "
                "results={results!r})".format(kind=self.op_kind,
                                              arity=self.op_arity,
                                              name=self.mlir_operation_name,
                                              operands=self.operand_map,
                                              results=self.result_map))

    def clone(self) -> "SimpleOpMapping":
        """Returns a copy of this mapping.

        The copy gets its own `op_kwargs` dict (so keyword arguments can be
        added to it without affecting this instance), but the ValueSpec
        objects and any already-set maps are shared by reference.
        """
        copy = SimpleOpMapping(self.op_f, *self.op_args, **self.op_kwargs)
        for name in ("outref_variant_value", "op_kind", "op_arity",
                     "operand_map", "result_map", "mlir_operation_name"):
            setattr(copy, name, getattr(self, name))
        return copy

    def with_outref_variant(self, value=None):
        """Instructs the registry to also generate an outref variant.

        This is done by cloning the op prior to finalizing and adding an
        out= parameter.

        Args:
          value: The value spec to use for the out= parameter. Defaults to a
            new TensorOutRef.

        Returns:
          self, to allow chaining.
        """
        self.outref_variant_value = TensorOutRef() if value is None else value
        return self

    @property
    def all_arg_values(self) -> List[ValueSpec]:
        """Returns all arg values (either positional or kw)."""
        return list(self.op_args) + list(self.op_kwargs.values())

    @property
    def is_outref_form(self) -> bool:
        """Whether the op contains an out parameter that aliases to the result."""
        return any(isinstance(a, TensorOutRef) for a in self.all_arg_values)

    def generate_example(self) -> Tuple[List, Dict]:
        """Generates an example signature for invoking the op.

        Each argument receives a distinct running index (positional arguments
        first, then keyword arguments in insertion order).

        Returns:
          (list, dict) of positional and keyword args.
        """
        positional = [
            op_arg.generate_example(index)
            for index, op_arg in enumerate(self.op_args)
        ]
        kw = {
            kw_name: kw_value.generate_example(index)
            for index, (kw_name, kw_value) in enumerate(
                self.op_kwargs.items(), start=len(positional))
        }
        return positional, kw

    def finalize(self):
        """Finalizes the mapping once all hints have been applied."""
        # Update the name on all args if undefined.
        self._assign_default_names()

        # Create an example graph and configure from it.
        self._configure_from_example()

        # Determine default operation name.
        if self.mlir_operation_name is None:
            self._set_default_mlir_operation_name()

    def _assign_default_names(self):
        """Names unnamed args: "arg<N>" for positional, the key for keyword."""
        for index, op_arg in enumerate(self.op_args):
            if op_arg.name is None:
                op_arg.name = "arg%d" % index
        for key, op_arg in self.op_kwargs.items():
            if op_arg.name is None:
                op_arg.name = key

    def _set_default_mlir_operation_name(self):
        """Derives `mlir_operation_name` from `op_kind` ("ns::name")."""
        op_ns, op_name = self.op_kind.split("::", maxsplit=1)
        # Since these are emitted into the "aten" dialect namespace, alias them
        # to omit the prefix to distinguish from custom ops and others (which
        # will have a prefix).
        default_name = op_name if op_ns == "aten" else op_ns + "." + op_name
        if self.is_outref_form:
            default_name += ".inplace"
        self.mlir_operation_name = default_name

    def _configure_from_example(self):
        """Traces an example invocation and fills in op metadata from it.

        Sets `op_kind` and `op_arity` unconditionally; sets `operand_map` and
        `result_map` only if they have not already been provided.

        Returns:
          self.

        Raises:
          ValueError: If a result of the op is not a tensor type.
        """
        # Trace the op so that we get an example graph like this:
        #   %0 : Float(2, 3, 7) = prim::Constant[value=<Tensor>]()
        #   %1 : Float(2, 3, 7) = prim::Constant[value=<Tensor>]()
        #   %2 : float = prim::Constant[value=3.]()
        #   %3 : Float(2, 3, 7) = aten::add(%0, %1, %2)
        #   return (%3)
        # The def of the return value is expected to be the modeled op. The
        # inputs to that op are expected to be captured constants that can be
        # re-associated to the example inputs.
        example_args, example_kwargs = self.generate_example()

        def forward():
            return self.op_f(*example_args, **example_kwargs)

        trace = torch.jit.trace(forward, tuple())
        graph = trace.graph
        logging.debug("Graph for op %r: %s", self.op_f, graph)

        # Track up from the return node and assume this is our op.
        return_node = graph.return_node()
        return_inputs = list(return_node.inputs())
        assert len(return_inputs) == 1, "Expected one input return"
        op_node = return_inputs[0].node()
        op_inputs = list(op_node.inputs())
        logging.debug("Found op node: %r", op_node)

        # Meta-data about the source op.
        self.op_kind = op_node.kind()
        self.op_arity = len(op_inputs)
        if self.operand_map is None:
            self.operand_map = self._associate_inputs(op_inputs, example_args,
                                                      example_kwargs)

        # Results.
        op_outputs = list(op_node.outputs())
        if self.result_map is None:
            if self.is_outref_form:
                # Only support single outref results.
                assert len(op_outputs) == 1, (
                    "For outref ops, only a single output is supported")
                self.result_map = [(0, TensorOutRef("result"))]
            else:
                # Map results in order. A lone result is just "result";
                # multiple results are numbered.
                single_result = len(op_outputs) == 1
                self.result_map = []
                for i, op_output in enumerate(op_outputs):
                    op_output_type = op_output.type()
                    if not isinstance(op_output_type, torch.TensorType):
                        raise ValueError(
                            "Unsupported op output type: {!r}".format(
                                op_output_type))
                    name = "result" if single_result else "result%d" % i
                    self.result_map.append((i, TensorValue(name)))
        return self

    def _associate_inputs(self, op_inputs, example_args, example_kwargs):
        """Given inputs to a graph op node, associates to the input args.

        This will match up example arguments with what was produced in the
        graph, producing the operand_map. Inputs whose immediate value cannot
        be extracted (None) are left out of the map.

        Returns:
          List of (input_index, ValueSpec) mapping inputs to the graph node to
          provided values in the op definition.

        Raises:
          KeyError: If an extracted immediate matches none of the example
            arguments.
        """
        assert len(example_args) == len(self.op_args)
        assert example_kwargs.keys() == self.op_kwargs.keys()

        def find_arg(value):
            for i, arg in enumerate(example_args):
                if _is_same_value(arg, value):
                    return self.op_args[i]
            for key, arg in example_kwargs.items():
                if _is_same_value(arg, value):
                    return self.op_kwargs[key]
            raise KeyError("Op input not in arguments: {!r} -> {!r}".format(
                value, op_inputs))

        operand_map = []
        for i, op_input in enumerate(op_inputs):
            immediate_value = _extract_immediate(op_input.node())
            if immediate_value is not None:
                operand_map.append((i, find_arg(immediate_value)))
        return operand_map


class OpRegistry:
    """Maintains a registry of op mappings."""

    def __init__(self):
        super().__init__()
        self._mappings = []
        self._pending_mapping = None

    def op(self, op_f, *op_args, **op_kwargs):
        """Forwards to the SimpleOpMapping constructor and adds it.

        The mapping is not finalized until either the registry is finalized or
        the next op mapping is added. This allows tweaks to the mapping to be
        done inline prior to performing detailed introspection.

        Returns:
          The SimpleOpMapping instance.
        """
        self._finalize_pending()
        m = SimpleOpMapping(op_f, *op_args, **op_kwargs)
        self._pending_mapping = m
        return m

    @property
    def mappings(self) -> Sequence[OpMapping]:
        """Returns the list of OpMapping.

        Any pending mapping is finalized first.

        Returns:
          Sequence of OpMapping concrete classes (most commonly
          SimpleOpMapping).
        """
        self._finalize_pending()
        return self._mappings

    def _finalize_pending(self):
        """Finalizes the pending mapping (and its outref variant), if any."""
        pending_mapping = self._pending_mapping
        if pending_mapping is None:
            return
        # Clear the pending slot before finalizing.
        self._pending_mapping = None

        outref_mapping = None
        if pending_mapping.outref_variant_value is not None:
            # Generate a variant (with an out= form). This must be cloned
            # before the original is finalized so that it is configured from
            # its own example graph.
            outref_mapping = pending_mapping.clone()
            outref_mapping.op_kwargs["out"] = (
                outref_mapping.outref_variant_value)
            outref_mapping.outref_variant_value = None

        # Finalize the original.
        pending_mapping.finalize()
        self._mappings.append(pending_mapping)

        # Finalize the outref form if generated.
        if outref_mapping is not None:
            outref_mapping.finalize()
            self._mappings.append(outref_mapping)