Representing Symbolic Shape Expressions in Torch Dialect (#3372)

Example Torch dialect program with symbolic shape expressions:
```mlir
module {
  func.func @main(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
    %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int
    %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int

    torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
    torch.bind_symbolic_shape %arg1, [%0, %2], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>

    %3 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
    torch.bind_symbolic_shape %3, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>

    %4 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
    torch.bind_symbolic_shape %4, [%0, %2], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>

    %5 = torch.prim.ListConstruct %3, %3, %4 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>
    %int1 = torch.constant.int 1
    %6 = torch.aten.cat %5, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>
    torch.bind_symbolic_shape %6, [%0, %1, %2], affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>

    return %6 : !torch.vtensor<[?,?,3],f32>
  }
}
```

For reference, this is the TorchDynamo exported program with symbolic
shape expressions that the above Torch dialect program is imported from:
```py
ExportedProgram:                                                                                                                                                                                             
    class GraphModule(torch.nn.Module):                                                                                                                                                                      
        def forward(self, x: "f32[s0, s1, 3]", y: "f32[s0, s3, 3]"):                                                                                                                                         
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:31 in forward, code: a = torch.tanh(x)                                        
            tanh: "f32[s0, s1, 3]" = torch.ops.aten.tanh.default(x);  x = None                                                                                                                               
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:32 in forward, code: b = torch.sigmoid(y)                                     
            sigmoid: "f32[s0, s3, 3]" = torch.ops.aten.sigmoid.default(y);  y = None                                                                                                                         
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:33 in forward, code: return torch.cat((a, a, b), dim=1)                       
            cat: "f32[s0, 2*s1 + s3, 3]" = torch.ops.aten.cat.default([tanh, tanh, sigmoid], 1);  tanh = sigmoid = None                                                                                      
            return (cat,)                                                                                                                                                                                    
                                                                                                                                                                                                             
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)])                                               
Range constraints: {s0: ValueRanges(lower=5, upper=10, is_bool=False), s1: ValueRanges(lower=0, upper=100, is_bool=False), s3: ValueRanges(lower=0, upper=50, is_bool=False)} 
```
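
For completeness, here is a minimal sketch of how this kind of IR can be generated end-to-end using the `fx.export_and_import` entry point updated in this PR. The module and dynamic dim constraints below are reconstructed from the example above and the `symbolic_shape_expr_test.py` test added in this PR, so the exact shape symbol names assigned by torch.export (s0/s1/s3 above) may differ:
```py
import torch
from torch.export import Dim
from torch_mlir import fx

class TanhSigmoidCat(torch.nn.Module):
    def forward(self, x, y):
        a = torch.tanh(x)
        b = torch.sigmoid(y)
        return torch.cat((a, a, b), dim=1)

# Sample inputs consistent with the dynamic dim constraints below.
x = torch.randn(5, 2, 3)
y = torch.randn(5, 6, 3)

# Dynamic dim constraints mirroring the range constraints shown above
# (s0 in [5, 10], s1 <= 100, s3 <= 50).
dim_n = Dim("n", min=5, max=10)
dim_x1 = Dim("x1", max=100)
dim_y1 = Dim("y1", max=50)
dynamic_shapes = {
    "x": {0: dim_n, 1: dim_x1},
    "y": {0: dim_n, 1: dim_y1},
}

m = fx.export_and_import(
    TanhSigmoidCat(),
    x,
    y,
    dynamic_shapes=dynamic_shapes,
    import_symbolic_shape_expressions=True,
)
print(m)
```
The `import_symbolic_shape_expressions` flag defaults to `False`, so the importer's existing behavior is unchanged unless it is explicitly enabled.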

Huge credit to @stellaraccident for the inputs that helped evaluate the
various design options and arrive at the representation of choice.


- [x] Op definitions for symbolic_int and bind_symbolic_shape ops
- [x] fx_importer updates to import range constraints + create
symbolic_int ops
- [x] fx_importer changes for AffineMapAttr building + adding
bind_symbolic_shape ops
- [x] custom printer/parser for inlined AffineMap expressions in MLIR
assembly
- [x] Dialect lit test
- [x] fx_importer python lit tests
- [ ] Cleanup pass to remove these ops (can add in a follow-on; a rough
sketch of what such a cleanup could do is included below)
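
As an illustration of that follow-on cleanup item only (it is not part of this PR, and `strip_symbolic_shape_ops` is a hypothetical helper, not an API added here), a downstream consumer could drop these ops with the torch-mlir Python IR bindings along these lines:
```py
# Hypothetical sketch (not part of this PR): strip the symbolic shape ops
# from an imported module using the torch-mlir Python IR bindings.
from torch_mlir import ir

def strip_symbolic_shape_ops(module: ir.Module):
    bind_ops = []
    sym_ops = []

    def _op(o):
        # Block iteration may yield OpView wrappers; normalize to ir.Operation.
        return o.operation if isinstance(o, ir.OpView) else o

    def collect(op: ir.Operation):
        # Recursively visit nested regions/blocks, recording ops to erase.
        for region in op.regions:
            for block in region.blocks:
                for nested in block.operations:
                    nested = _op(nested)
                    if nested.name == "torch.bind_symbolic_shape":
                        bind_ops.append(nested)
                    elif nested.name == "torch.symbolic_int":
                        sym_ops.append(nested)
                    collect(nested)

    collect(module.operation)
    # Erase users (bind_symbolic_shape) before producers (symbolic_int).
    for op in bind_ops:
        op.erase()
    for op in sym_ops:
        op.erase()
```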
Sambhav Jain 2024-06-07 04:04:03 -07:00 committed by GitHub
parent 431d98b405
commit d0a818a03e
11 changed files with 996 additions and 9 deletions

View File

@@ -11,6 +11,7 @@
#define TORCH_OPS
include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/OpAsmInterface.td" include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td" include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/CastInterfaces.td"
@ -1337,4 +1338,67 @@ def Torch_DtypeCalculateYieldDtypesOp : Torch_Op<"dtype.calculate.yield.dtypes",
let hasVerifier = 1; let hasVerifier = 1;
} }
//===----------------------------------------------------------------------===//
// Symbolic shape modeling ops for TorchDynamo frontend.
//===----------------------------------------------------------------------===//
def Torch_SymbolicIntOp : Torch_Op<"symbolic_int", [Pure]> {
let summary = "Symbolic int representing a dynamic dimension";
let description = [{
The `torch.symbolic_int` operation captures a dynamic dimension on the
global function arguments as exported by TorchDynamo (torch.export).
It associates the shape symbols (e.g. "s0", "s1") with the
global SSA values (e.g. `%0`, `%1`) that are then referenced
to bind shapes on op results.
Additionally, the operation annotates `min_val` and `max_val` attributes
denoting the range constraints for the dynamic dimension. This may be
useful for modeling runtime shape guards, or compile-time optimizations
based on the shape bounds (min, opt, max) on results of ops / regions.
Example:
```
%0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 2, max_val = 20} : !torch.int
```
}];
let arguments = (ins
StrAttr:$symbol_name,
I64Attr:$min_val,
I64Attr:$max_val
);
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = [{
$symbol_name ` ` `{` `min_val` `=` $min_val `,` `max_val` `=` $max_val `}` attr-dict `:` type($result)
}];
}
def Torch_BindSymbolicShapeOp : Torch_Op<"bind_symbolic_shape", []> {
let summary = "Binds shape expressions to tensors using an affine map indexed by shape symbols";
let description = [{
The `torch.bind_symbolic_shape` operation binds shape expressions
useful to compute the dynamic dimensions of a tensor. It takes a
variadic list of SSA symbol values that map 1:1 to the local symbols
declared in the affine map. The affine map contains one affine shape
expression per dimension, with the declared symbols as its terminals.
Example:
```
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
torch.bind_symbolic_shape %out0, [%0, %1, %2], affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>
```
}];
let arguments = (ins
Torch_ValueTensorType:$operand,
Variadic<Torch_IntType>:$shape_symbols,
Builtin_AffineMapAttr:$shape_expressions
);
let results = (outs);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
#endif // TORCH_OPS

View File

@@ -5034,3 +5034,65 @@ LogicalResult InitializeGlobalSlotsOp::verify() {
return emitOpError("expected number of operands to match number of slots");
return success();
}
//===----------------------------------------------------------------------===//
// BindSymbolicShapeOp
//===----------------------------------------------------------------------===//
//
// torch.bind_symbolic_shape %6, [%0, %1, %2], affine_map<()[s0, s1, s2] ->
// (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>
//
ParseResult BindSymbolicShapeOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand operand;
SmallVector<OpAsmParser::UnresolvedOperand> shapeSymbols;
AffineMapAttr shapeExpressions;
Type operandType;
if (parser.parseOperand(operand) || parser.parseComma() ||
parser.parseLSquare() || parser.parseOperandList(shapeSymbols) ||
parser.parseRSquare() || parser.parseComma() ||
parser.parseAttribute(shapeExpressions, "shape_expressions",
result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(operandType)) {
return failure();
}
if (parser.resolveOperand(operand, operandType, result.operands) ||
parser.resolveOperands(shapeSymbols,
parser.getBuilder().getType<Torch::IntType>(),
result.operands)) {
return failure();
}
return success();
}
// Use a custom printer here to prevent the AffineMap from being hoisted
// when printed. This keeps the AffineMap inline with the op.
void BindSymbolicShapeOp::print(OpAsmPrinter &p) {
p << " " << getOperand() << ", [";
llvm::interleaveComma(getShapeSymbols(), p);
p << "], " << "affine_map<" << getShapeExpressions().getValue() << ">";
p.printOptionalAttrDict((*this)->getAttrs(),
/*elidedAttrs=*/{"shape_expressions"});
p << " : " << getOperand().getType();
}
LogicalResult BindSymbolicShapeOp::verify() {
if (getShapeSymbols().empty())
return emitOpError() << "requires non-empty shapeSymbols";
for (auto symbol : getShapeSymbols()) {
Operation *definingOp = symbol.getDefiningOp();
if (!isa<SymbolicIntOp>(definingOp)) {
return emitOpError()
<< "shape symbol must be produced by a SymbolicIntOp";
}
}
return success();
}

View File

@@ -49,6 +49,9 @@ class FxImporterTestConfig(TestConfig):
prog,
output_type=self._output_type,
func_name=artifact.__class__.__name__,
# While the current e2e tests don't exercise symbolic shapes,
# enabling this here ensures they don't regress either.
import_symbolic_shape_expressions=True,
)
module = self._backend.compile(module)
backend_module = self._backend.load(module)

View File

@@ -14,6 +14,8 @@ except ImportError:
import logging
import operator
import re
import sympy
import math
from dataclasses import dataclass
from types import BuiltinMethodType, BuiltinFunctionType
from typing import (
@@ -81,6 +83,14 @@ from torch.fx.node import (
)
from ..ir import (
AffineAddExpr,
AffineConstantExpr,
AffineExpr,
AffineMap,
AffineMapAttr,
AffineModExpr,
AffineMulExpr,
AffineSymbolExpr,
Attribute,
Block,
Context,
@@ -258,6 +268,71 @@ else:
SYMBOLIC_TORCH_OPS = {key for key in SYMBOLIC_OP_TO_TORCH_OP}
@dataclass
class RangeConstraint:
min_val: int
max_val: int
def sympy_expr_to_semi_affine_expr(
expr: sympy.Expr, symbols_map: Dict[str, AffineSymbolExpr]
) -> AffineExpr:
"""Translate sympy expressions to MLIR (semi-)affine expressions.
Recursively traverse the sympy expr AST and build the affine expr.
This is not a perfect translation. Sympy expressions are much more
expressive and not as constrained as affine (linear) expressions are.
However, for the most part, we don't need to support all of sympy.
PyTorch only uses a subset of sympy for capturing and expressing
symbolic shapes, and among what's supported, we expect the semi-affine
expressions (https://mlir.llvm.org/docs/Dialects/Affine/#semi-affine-maps)
to be sufficient.
"""
if isinstance(expr, sympy.Symbol):
return symbols_map[str(expr)]
elif isinstance(expr, (int, sympy.Integer)):
return AffineConstantExpr.get(expr)
# This handles both add (`s0 + c`) and subtract (`s0 - c`).
# The expression is `sympy.Add` in both cases but with args
# (s0, c) in first case and (s0, -c) in the second case.
elif isinstance(expr, sympy.Add):
affine_expr = AffineConstantExpr.get(0)
for arg in expr.args:
affine_expr = AffineAddExpr.get(
affine_expr, sympy_expr_to_semi_affine_expr(arg, symbols_map)
)
return affine_expr
elif isinstance(expr, sympy.Mul):
affine_expr = AffineConstantExpr.get(1)
for arg in expr.args:
affine_expr = AffineMulExpr.get(
affine_expr, sympy_expr_to_semi_affine_expr(arg, symbols_map)
)
return affine_expr
elif isinstance(expr, sympy.Pow):
base, exp = expr.args
# Only integer exponent is supported
# So, s1 ** s0 isn't allowed.
assert isinstance(exp, (int, sympy.Integer))
assert exp > 0, "Only positive exponents supported in sympy.Pow"
affine_expr = AffineConstantExpr.get(1)
for _ in range(exp):
affine_expr = AffineMulExpr.get(
affine_expr, sympy_expr_to_semi_affine_expr(base, symbols_map)
)
return affine_expr
elif isinstance(expr, sympy.Mod):
dividend, divisor = expr.args
return AffineModExpr.get(
sympy_expr_to_semi_affine_expr(dividend, symbols_map),
sympy_expr_to_semi_affine_expr(divisor, symbols_map),
)
else:
raise NotImplementedError(
f"Translation of sympy.Expr of type {type(expr)} not implemented yet."
)
@dataclass(frozen=True)
class SparsityMeta:
"""
@@ -478,6 +553,7 @@ class FxImporter:
*,
func_name: str = "main",
func_visibility: Optional[str] = None,
import_symbolic_shape_expressions: bool = False,
) -> Operation:
"""Imports an ExportedProgram according to our chosen canonical representation.
@@ -527,6 +603,10 @@ class FxImporter:
sig = prog.graph_signature
# Populate symbolic guards for dynamic shapes (if any)
if import_symbolic_shape_expressions:
self._cc.set_symbolic_guards(prog)
# Invert the (producer, node_name) maps for mutated user inputs and mutated
# buffers. This is because we hit-detect based on the input node name.
mutated_user_inputs = {
@@ -682,7 +762,9 @@ class FxImporter:
# Import all nodes and return.
node_importer.import_nodes(
all_producer_nodes.values(),
skip_placeholders_outputs=True,
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
)
node_importer.return_node_values(loc, user_outputs)
self.symbol_table.insert(func_op)
@@ -694,6 +776,7 @@ class FxImporter:
*,
func_name: str = "main",
func_visibility: Optional[str] = None,
import_symbolic_shape_expressions: bool = False,
) -> Operation:
"""Imports a consolidated torch.export.ExportedProgram instance.
@@ -728,6 +811,10 @@ class FxImporter:
state_dict = prog.state_dict
arg_replacements: Dict[str, Any] = {}
# Populate symbolic guards for dynamic shapes (if any)
if import_symbolic_shape_expressions:
self._cc.set_symbolic_guards(prog)
# If there is no "constants" attribute, consult the "state_dict". Otherwise, only look
# at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
if hasattr(prog, "constants"):
@@ -774,7 +861,10 @@ class FxImporter:
g.erase_node(node)
return self.import_stateless_graph(
g,
func_name=func_name,
func_visibility=func_visibility,
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
)
def import_graph_module(self, gm: GraphModule) -> Operation:
@@ -791,6 +881,7 @@ class FxImporter:
*,
func_name: str = "main",
func_visibility: Optional[str] = None,
import_symbolic_shape_expressions: bool = False,
) -> Operation:
"""Low-level import of a functionalized, assumed stateless Graph as a func.
@@ -815,7 +906,9 @@ class FxImporter:
self._cc,
entry_block,
)
node_importer.import_nodes(
g.nodes, import_symbolic_shape_expressions=import_symbolic_shape_expressions
)
self.symbol_table.insert(func)
return func
@@ -870,6 +963,7 @@ class ContextCache:
"_c",
"_dtype_to_type",
"_tensor_metadata_cache",
"_symbolic_guards",
"_py_attr_tracker", "_py_attr_tracker",
# Types. # Types.
"torch_bool_type", "torch_bool_type",
@ -888,6 +982,7 @@ class ContextCache:
self._tensor_metadata_cache: Dict[ self._tensor_metadata_cache: Dict[
Tuple[torch.Size, torch.dtype, Optional[SparsityMeta], bool], IrType Tuple[torch.Size, torch.dtype, Optional[SparsityMeta], bool], IrType
] = {} ] = {}
self._symbolic_guards: Dict = {}
self._py_attr_tracker = py_attr_tracker or RefTracker()
# Common types.
@@ -1037,6 +1132,52 @@ class ContextCache:
return Location.file(filename, line, col=0, context=self._c)
return Location.unknown(context=self._c)
def set_symbolic_guards(
self, prog: torch.export.ExportedProgram
) -> Dict[str, RangeConstraint]:
def _sympy_int_to_int(val: sympy.Expr, adjust_func: Callable):
# Convert simple sympy Integers into concrete int
if val == sympy.oo:
return math.inf
if val == -sympy.oo:
return -math.inf
if isinstance(val, sympy.Integer):
return int(val)
# TODO: Remove this adjustment when fractional ranges are removed
return adjust_func(val)
contains_symbolic_ints = False
for val in prog.range_constraints.values():
if (
isinstance(val.lower, sympy.Integer)
and isinstance(val.upper, sympy.Integer)
and not val.is_bool
):
contains_symbolic_ints = True
break
if contains_symbolic_ints:
# Build a map from shape symbol name to `RangeConstraint` object
# capturing `min_val` and `max_val` constraints for that
# symbol. Translate sympy integers to regular integers.
#
# Example:
# {
# 's0': RangeConstraint(min_val=5, max_val=10),
# 's1': RangeConstraint(min_val=0, max_val=100),
# 's3': RangeConstraint(min_val=0, max_val=9223372036854775806),
# }
self._symbolic_guards = {
str(k): RangeConstraint(
_sympy_int_to_int(v.lower, math.ceil),
_sympy_int_to_int(v.upper, math.floor),
)
for k, v in prog.range_constraints.items()
}
def get_symbolic_guards(self) -> Dict[str, RangeConstraint]:
return self._symbolic_guards
class GraphNodeImporter:
"""Imports graph nodes into an MLIR function.
@@ -1050,6 +1191,7 @@ class GraphNodeImporter:
"_cc",
"_on_node_produced",
"_v",
"_symbol_to_value",
"_multi_result_nodes", "_multi_result_nodes",
"fx_importer", "fx_importer",
] ]
@ -1068,6 +1210,8 @@ class GraphNodeImporter:
# Map of (Node, result_index) to MLIR Value or a callback that lazily # Map of (Node, result_index) to MLIR Value or a callback that lazily
# constructs and returns a value. # constructs and returns a value.
self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {} self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {}
# Map of Shape Symbol to MLIR Value
self._symbol_to_value: Dict[str, Value] = {}
# Map of node name to hook that should be called when it is produced.
self._on_node_produced: Dict[str, Callable[[Value], None]] = {}
# Statically multi-result nodes which we have de-tupled are noted here.
@@ -1108,6 +1252,28 @@ class GraphNodeImporter:
self._v[key] = value
return value
def bind_symbol_value(
self,
shape_symbol: str,
value: Value,
):
"""Binds a shape symbol to a global SSA value (and asserts if already bound)."""
assert (
shape_symbol not in self._symbol_to_value
), f"Symbol already has a value: {shape_symbol}"
self._symbol_to_value[shape_symbol] = value
def resolve_symbol_value(self, shape_symbol: str) -> Value:
"""Resolves a shape symbol to a value."""
try:
binding = self._symbol_to_value[shape_symbol]
except KeyError:
raise KeyError(
f"Shape symbol {shape_symbol} has not been bound to an MLIR value"
)
if isinstance(binding, Value):
return binding
def import_mutable_to_vtensor(
self, loc: Location, node: Node, mutable_value: Value, producer_node_name: str
) -> Value:
@@ -1190,10 +1356,20 @@
func_dialect.ReturnOp(operands, loc=loc)
def import_nodes(
self,
nodes: Iterable[Node],
*,
skip_placeholders_outputs: bool = False,
import_symbolic_shape_expressions: bool = False,
):
with InsertionPoint(self._b):
loc = Location.unknown()
# Import dynamic shape symbols and guards (if any)
if import_symbolic_shape_expressions:
symbolic_guards = self._cc.get_symbolic_guards()
self._import_shape_symbols_with_guards(loc, symbolic_guards)
num_placeholders = 0
for node in nodes:
op = node.op
@@ -1253,6 +1429,8 @@
operands = [self._import_argument(loc, arg) for arg in node.args[0]]
func_dialect.ReturnOp(operands, loc=loc)
self._create_bind_symbolic_shape_ops(loc, node)
def _promote_symbolic_scalar_int_float(self, loc, graph, param):
temp_target = torch.ops.aten.Float.Scalar
temp_node = Node(
@@ -1516,6 +1694,69 @@
for i, value in enumerate(operation.results):
self.bind_node_value(node, value, i)
def _import_shape_symbols_with_guards(
self, loc: Location, symbolic_guards: Dict[str, RangeConstraint]
):
for symbol, constraints in symbolic_guards.items():
# Create torch.symbolic_int ops
operation = Operation.create(
name="torch.symbolic_int",
attributes={
"symbol_name": StringAttr.get(symbol),
"min_val": self._cc.integer_attr(constraints.min_val, 64),
"max_val": self._cc.integer_attr(constraints.max_val, 64),
},
results=[self._cc.torch_int_type],
loc=loc,
)
self.bind_symbol_value(symbol, operation.result)
def _create_bind_symbolic_shape_ops(self, loc: Location, node: torch_fx.Node):
node_val = node.meta.get("val")
if (node_val is not None) and isinstance(node_val, TorchFakeTensor):
# Only create bind ops if the shapes contain symbolic sizes.
# Query the bool attribute `_has_symbolic_sizes_strides` on node.meta["val"].
if node_val._has_symbolic_sizes_strides:
# Read node metadata to obtain shape symbols and expressions
symbols_set = set()
shape_exprs = []
for s in node_val.size():
if isinstance(s, torch.SymInt):
symbols_set.update(s.node.expr.free_symbols)
shape_exprs.append(s.node.expr)
else:
assert isinstance(s, int)
shape_exprs.append(s)
# Map from sympy shape symbols to local symbols in the affine map
symbols_set = sorted(symbols_set, key=lambda x: x.name)
symbols_map = {
str(symbol): AffineSymbolExpr.get(i)
for i, symbol in enumerate(symbols_set)
}
# Convert symbolic shape expressions into affine expressions
affine_exprs = [
sympy_expr_to_semi_affine_expr(expr, symbols_map)
for expr in shape_exprs
]
affine_map = AffineMap.get(0, len(symbols_set), affine_exprs)
# Build operand list
operand_list = []
operand_list.append(self.resolve_node_value(node))
for symbol in symbols_map.keys():
operand_list.append(self.resolve_symbol_value(symbol))
# Create torch.bind_symbolic_shape ops
Operation.create(
name="torch.bind_symbolic_shape",
attributes={"shape_expressions": AffineMapAttr.get(affine_map)},
operands=operand_list,
loc=loc,
)
def _import_argument(
self, loc: Location, arg: NodeArgument, expected_jit_type=None
) -> Value:

View File

@@ -54,6 +54,7 @@ def export_and_import(
fx_importer: Optional[FxImporter] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
experimental_support_mutation: bool = False,
import_symbolic_shape_expressions: bool = False,
hooks: Optional[FxImporterHooks] = None,
decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
func_name: str = "main",
@@ -79,9 +80,17 @@ def export_and_import(
if experimental_support_mutation:
if torch.__version__ < "2.3.0.dev20240207":
warnings.warn("Mutable program import only supported on PyTorch 2.3+")
fx_importer.import_program(
prog,
func_name=func_name,
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
)
else:
fx_importer.import_frozen_program(
prog,
func_name=func_name,
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
)
return _module_lowering(
enable_ir_printing, OutputType.get(output_type), fx_importer.module

View File

@@ -3026,3 +3026,35 @@ func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!
%1 = torch.copy.to_tensor %0 : !torch.tensor
return %1 : !torch.tensor
}
// -----
// CHECK-LABEL: @torch.symbolic_int$canonicalize(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
// CHECK-NOT: %[[S1:.*]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int
// CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
// CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32>
// CHECK: %[[V1:.*]] = torch.aten.slice.Tensor %[[ARG1]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: torch.bind_symbolic_shape %[[V1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
// CHECK: %[[V2:.*]] = torch.aten.add.Tensor %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: torch.bind_symbolic_shape %[[V2]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
// CHECK: return %[[V2]] : !torch.vtensor<[?],f32>
func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
%0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
%1 = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int
torch.bind_symbolic_shape %arg0, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
torch.bind_symbolic_shape %arg1, [%0], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32>
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int9223372036854775807 = torch.constant.int 9223372036854775807
%int1_0 = torch.constant.int 1
%2 = torch.aten.slice.Tensor %arg1, %int0, %int1, %int9223372036854775807, %int1_0 : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
torch.bind_symbolic_shape %2, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
%int1_1 = torch.constant.int 1
%3 = torch.aten.add.Tensor %arg0, %2, %int1_1 : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
torch.bind_symbolic_shape %3, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
return %3 : !torch.vtensor<[?],f32>
}

View File

@@ -375,3 +375,22 @@ func.func @foo(%arg0: !torch.vtensor<[64,64],f32,#SV>) -> !torch.vtensor<[64,64]
// expected-error @+1 {{invalid sparsity encoding attribute}}
func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345>
// -----
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
%0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
// expected-error @+1 {{op requires non-empty shapeSymbols}}
torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
return %arg0 : !torch.vtensor<[?],f32>
}
// -----
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
%int0 = torch.constant.int 0
// expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}}
torch.bind_symbolic_shape %arg0, [%int0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
return %arg0 : !torch.vtensor<[?],f32>
}

View File

@@ -89,6 +89,11 @@ def test_import_frozen_exported_program_with_func_name():
@run
# CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32>
# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,4],f32> -> !torch.vtensor<[?,4],f32>
# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
# CHECK: return %[[TANH]] : !torch.vtensor<[?,4],f32>
def test_import_frozen_exported_program_with_dynamic_shapes():
class Basic(nn.Module):
def __init__(self):
@@ -100,7 +105,11 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
batch = Dim("batch")
dynamic_shapes = {"x": {0: batch}}
m = fx.export_and_import(
Basic(),
torch.randn(3, 4),
dynamic_shapes=dynamic_shapes,
func_name="test_net",
import_symbolic_shape_expressions=True,
)
print(m)
@@ -108,6 +117,12 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
@run
# CHECK-LABEL: test_broadcast_with_dynamic_shapes
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32>
# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
# CHECK: torch.aten.size.int
# CHECK: torch.prim.ListConstruct
# CHECK: %[[EXPAND:.*]] = torch.aten.expand
# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],f32>
def test_broadcast_with_dynamic_shapes():
class Basic(nn.Module):
def __init__(self):
@@ -127,7 +142,12 @@ def test_broadcast_with_dynamic_shapes():
}
m = fx.export_and_import(
Basic(),
x,
y,
dynamic_shapes=dynamic_shapes,
func_name="test_net",
import_symbolic_shape_expressions=True,
)
print(m)

View File

@@ -0,0 +1,463 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
# This file contains tests of various op special forms that the fx_importer
# handles.
import torch
import torch.export
import torch.nn as nn
from torch.export import Dim
from torch_mlir import fx
def run(f):
print(f"{f.__name__}")
print("-" * len(f.__name__))
f()
print()
@run
# CHECK-LABEL: test_tanh_sigmoid_cat
# CHECK: func.func @main(
# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int
# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int
# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: %[[TANH:.+]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: %[[SIG:.+]] = torch.aten.sigmoid %[[ARG1]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[SIG]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[TANH]], %[[TANH]], %[[SIG]], %[[ARG2]] : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>
# CHECK: %[[CAT:.+]] = torch.aten.cat %[[LIST]], {{.*}} : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[CAT]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: return %[[CAT]] : !torch.vtensor<[?,?,3],f32>
def test_tanh_sigmoid_cat():
class TanhSigmoidCat(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y, z):
a = torch.tanh(x)
b = torch.sigmoid(y)
return torch.cat((a, a, b, z), dim=1)
# Sample inputs
x = torch.randn(5, 2, 3)
y = torch.randn(5, 6, 3)
z = torch.randn(5, 4, 3)
# Dynamic dim constraints
dim_n = Dim("n", min=5, max=10)
dim_x1 = Dim("x1", max=100)
dim_y1 = Dim("y1", max=50)
dim_z1 = Dim("z1")
dynamic_shapes = {
"x": {0: dim_n, 1: dim_x1},
"y": {0: dim_n, 1: dim_y1},
"z": {0: dim_n, 1: dim_z1},
}
m = fx.export_and_import(
TanhSigmoidCat(),
x,
y,
z,
dynamic_shapes=dynamic_shapes,
import_symbolic_shape_expressions=True,
)
print(m)
@run
# CHECK-LABEL: test_symbolic_dim_differ_by_one
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+)
# CHECK-DISABLED: %[[S1:.+]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32>
# CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %arg1, {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
# CHECK: torch.bind_symbolic_shape %[[SLICE]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
# CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[ARG0]], %[[SLICE]], {{.*}} : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
# CHECK: torch.bind_symbolic_shape %[[ADD]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
# CHECK: return %[[ADD]] : !torch.vtensor<[?],f32>
def test_symbolic_dim_differ_by_one():
class SymbolicDimDifferByOne(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
return x + y[1:]
# Sample inputs
x = torch.randn(5)
y = torch.randn(6)
# Dynamic dim constraints
dimx = Dim("dimx", min=3, max=6)
dimy = dimx + 1
dynamic_shapes = {
"x": {0: dimx},
"y": {0: dimy},
}
m = fx.export_and_import(
SymbolicDimDifferByOne(),
x,
y,
dynamic_shapes=dynamic_shapes,
experimental_support_mutation=True,
import_symbolic_shape_expressions=True,
)
print(m)
@run
# CHECK-LABEL: test_outer_with_squared_shape
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
# CHECK: %[[VIEW1:.+]] = torch.aten.view %[[ARG0]], {{.*}} : !torch.vtensor<[?],f32>, !torch.list<int> -> !torch.vtensor<[?,1],f32>
# CHECK: torch.bind_symbolic_shape %[[VIEW1]], [%[[S0]]], affine_map<()[s0] -> (s0, 1)> : !torch.vtensor<[?,1],f32>
# CHECK: %[[MUL:.+]] = torch.aten.mul.Tensor %[[VIEW1]], %[[ARG0]] : !torch.vtensor<[?,1],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32>
# CHECK: torch.bind_symbolic_shape %[[MUL]], [%[[S0]]], affine_map<()[s0] -> (s0, s0)> : !torch.vtensor<[?,?],f32>
# CHECK: %[[VIEW2:.+]] = torch.aten.view %[[MUL]], {{.*}} : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
# CHECK: torch.bind_symbolic_shape %[[VIEW2]], [%[[S0]]], affine_map<()[s0] -> (s0 * s0)> : !torch.vtensor<[?],f32>
# CHECK: return %[[VIEW2]] : !torch.vtensor<[?],f32>
def test_outer_with_squared_shape():
class OuterWithSquaredShape(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.outer(x, x).flatten()
# Sample inputs
x = torch.rand(10)
# Dynamic dim constraints
batch = Dim("batch")
dynamic_shapes = {"x": {0: batch}}
m = fx.export_and_import(
OuterWithSquaredShape(),
x,
dynamic_shapes=dynamic_shapes,
import_symbolic_shape_expressions=True,
)
print(m)
@run
# CHECK-LABEL: test_slice_tensor_static_output
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[2,1],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
# CHECK: %[[SLICE1:.+]] = torch.aten.slice.Tensor %[[ARG0]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
# CHECK: %[[SLICE2:.+]] = torch.aten.slice.Tensor %[[SLICE1]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32>
# CHECK: return %[[SLICE2]] : !torch.vtensor<[2,1],f32>
def test_slice_tensor_static_output():
class SliceTensorStaticOutput(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x[0:2, :1]
# Sample inputs
x = torch.randn(4, 3)
# Dynamic dim constraints
batch = Dim("batch", min=3)
dynamic_shapes = {"x": {0: batch}}
m = fx.export_and_import(
SliceTensorStaticOutput(),
x,
dynamic_shapes=dynamic_shapes,
import_symbolic_shape_expressions=True,
)
print(m)
@run
# CHECK-LABEL: test_slice_tensor_dynamic_output
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 9223372036854775806} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
# CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %[[ARG0]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
# CHECK: torch.bind_symbolic_shape %[[SLICE]], [%[[S0]]], affine_map<()[s0] -> (s0 - 5)> : !torch.vtensor<[?],f32>
# CHECK: return %[[SLICE]] : !torch.vtensor<[?],f32>
def test_slice_tensor_dynamic_output():
class SliceTensorDynamicOutput(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x[5:]
# Sample inputs
x = torch.randn(10)
# Dynamic dim constraints
dimx = Dim("dimx", min=5)
dynamic_shapes = {"x": {0: dimx}}
m = fx.export_and_import(
SliceTensorDynamicOutput(),
x,
dynamic_shapes=dynamic_shapes,
import_symbolic_shape_expressions=True,
)
print(m)
@run
# CHECK-LABEL: test_div_tensor_mixed_ranks
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[],f32>, %[[ARG1:.+]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[?,3],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
# CHECK: %[[DIV:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,3],f32> -> !torch.vtensor<[?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[DIV]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
# CHECK: return %[[DIV]] : !torch.vtensor<[?,3],f32>
def test_div_tensor_mixed_ranks():
class DivTensorMixedRanks(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
div = torch.div(x, y)
return div
# Sample inputs
x = torch.tensor(10.0)
y = torch.randn(2, 3)
# Dynamic dim constraints
batch = Dim("batch")
dynamic_shapes = {"x": None, "y": {0: batch}}
m = fx.export_and_import(
DivTensorMixedRanks(),
x,
y,
dynamic_shapes=dynamic_shapes,
import_symbolic_shape_expressions=True,
)
print(m)
@run
# CHECK-LABEL: test_shape_div
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,7],f32>) -> !torch.vtensor<[?,5],f32> {
# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+)
# CHECK-DISABLED: %[[S0:.+]] = torch.symbolic_int "5*s1" {min_val = 0, max_val = 5000} : !torch.int
# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = 2, max_val = 1000} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]]], affine_map<()[s0] -> (s0 * 5, 7)> : !torch.vtensor<[?,7],f32>
# CHECK: %[[VIEW:.+]] = torch.aten.view %[[ARG0]], {{.*}} : !torch.vtensor<[?,7],f32>, !torch.list<int> -> !torch.vtensor<[?,5],f32>
# CHECK: torch.bind_symbolic_shape %[[VIEW]], [%[[S1]]], affine_map<()[s0] -> (s0 * 7, 5)> : !torch.vtensor<[?,5],f32>
# CHECK: return %[[VIEW]] : !torch.vtensor<[?,5],f32>
def test_shape_div():
class ShapeDiv(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.reshape(-1, 5)
# Sample inputs
x = torch.rand(10, 7)
# Dynamic dim constraints
batch = Dim("batch", max=1000) * 5
dynamic_shapes = {"x": {0: batch}}
m = fx.export_and_import(
ShapeDiv(),
x,
dynamic_shapes=dynamic_shapes,
import_symbolic_shape_expressions=True,
)
print(m)
@run
# CHECK-LABEL: test_broadcast_unit_dim_to_static_with_unchanged_dim_dynamic
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,?],f32>) -> !torch.vtensor<[3,?],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (1, s0)> : !torch.vtensor<[1,?],f32>
# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,?],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,?],f32>
# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (3, s0)> : !torch.vtensor<[3,?],f32>
# CHECK: return %[[EXPAND]] : !torch.vtensor<[3,?],f32>
def test_broadcast_unit_dim_to_static_with_unchanged_dim_dynamic():
class BroadcastUnitDimToStaticWithUnchangedDimDynamic(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.broadcast_to(x, (3, -1))
# Sample inputs
x = torch.randn(1, 2)
# Dynamic dim constraints
dim_1 = Dim("dim_1")
dynamic_shapes = {"x": {1: dim_1}}
m = fx.export_and_import(
BroadcastUnitDimToStaticWithUnchangedDimDynamic(),
x,
dynamic_shapes=dynamic_shapes,
import_symbolic_shape_expressions=True,
)
print(m)
@run
# CHECK-LABEL: test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_static
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,2],f32>
# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],f32>
# CHECK: return %3 : !torch.vtensor<[?,2],f32>
def test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_static():
class BroadcastUnitDimToDynamicWithUnchangedDimStatic(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.broadcast_to(x, (y.shape[0], -1))
# Sample inputs
x = torch.randn(1, 2)
y = torch.randn(10)
# Dynamic dim constraints
dim_0 = Dim("dim_0")
dynamic_shapes = {"x": {}, "y": {0: dim_0}}
m = fx.export_and_import(
BroadcastUnitDimToDynamicWithUnchangedDimStatic(),
x,
y,
dynamic_shapes=dynamic_shapes,
import_symbolic_shape_expressions=True,
)
print(m)
@run
# CHECK-LABEL: test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,?],f32>, %[[ARG1:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (1, s0)> : !torch.vtensor<[1,?],f32>
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S1]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,?],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?],f32>
# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s1, s0)> : !torch.vtensor<[?,?],f32>
# CHECK: return %[[EXPAND]] : !torch.vtensor<[?,?],f32>
def test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic():
class BroadcastUnitDimToDynamicWithUnchangedDimDynamic(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.broadcast_to(x, (y.shape[0], -1))
# Sample inputs
x = torch.randn(1, 2)
y = torch.randn(10)
# Dynamic dim constraints
dim_0 = Dim("dim_0")
dim_1 = Dim("dim_1")
dynamic_shapes = {"x": {1: dim_1}, "y": {0: dim_0}}
m = fx.export_and_import(
BroadcastUnitDimToDynamicWithUnchangedDimDynamic(),
x,
y,
dynamic_shapes=dynamic_shapes,
import_symbolic_shape_expressions=True,
)
print(m)
@run
# CHECK-LABEL: test_broadcast_unit_dim_to_dynamic_with_rank_increase
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:.+]]: !torch.vtensor<[?,3,2],f32>) -> !torch.vtensor<[?,3,2],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0, 3, 2)> : !torch.vtensor<[?,3,2],f32>
# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,3,2],f32>
# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 3, 2)> : !torch.vtensor<[?,3,2],f32>
# CHECK: return %[[EXPAND]] : !torch.vtensor<[?,3,2],f32>
def test_broadcast_unit_dim_to_dynamic_with_rank_increase():
class BroadcastUnitDimToDynamicWithRankIncrease(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.broadcast_to(x, y.size())
# Sample inputs
x = torch.randn(1, 2)
y = torch.randn(4, 3, 2)
# Dynamic dim constraints
dim_0 = Dim("dim_0")
dynamic_shapes = {"x": {}, "y": {0: dim_0}}
m = fx.export_and_import(
BroadcastUnitDimToDynamicWithRankIncrease(),
x,
y,
dynamic_shapes=dynamic_shapes,
import_symbolic_shape_expressions=True,
)
print(m)
@run
# CHECK-LABEL: test_gather_elements
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.+]]: !torch.vtensor<[2,3],si64>) -> !torch.vtensor<[2,3],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
# CHECK: %[[GATHER:.+]] = torch.aten.gather %[[ARG0]], {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.bool -> !torch.vtensor<[2,3],f32>
# CHECK: return %[[GATHER]] : !torch.vtensor<[2,3],f32>
def test_gather_elements():
class GatherElements(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.gather(x, 0, y)
# Sample inputs
x = torch.randn(4, 3)
y = torch.tensor([[0, 0, 0], [1, 1, 1]])
# Dynamic dim constraints
batch = Dim("batch", min=3)
dynamic_shapes = {"x": {0: batch}, "y": {}}
m = fx.export_and_import(
GatherElements(),
x,
y,
dynamic_shapes=dynamic_shapes,
import_symbolic_shape_expressions=True,
)
print(m)

View File

@@ -0,0 +1,69 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
# This file contains tests checking translating sympy expressions to (semi-)affine expressions.
from sympy import Symbol
from torch_mlir.extras.fx_importer import sympy_expr_to_semi_affine_expr
from torch_mlir.ir import (
AffineSymbolExpr,
Context,
)
def run(f):
print(f"{f.__name__}")
print("-" * len(f.__name__))
f()
print()
@run
# CHECK-LABEL: test_sympy_to_semi_affine_expr_translation
def test_sympy_to_semi_affine_expr_translation():
with Context():
s0 = Symbol("s0", positive=True, integer=True)
s1 = Symbol("s1", positive=True, integer=True)
symbols_set = sorted({s0, s1}, key=lambda x: x.name)
symbols_map = {
str(symbol): AffineSymbolExpr.get(i) for i, symbol in enumerate(symbols_set)
}
SYMPY_EXPRS = [
# CHECK: 10
(10),
# CHECK: s0
(s0),
# CHECK: s0
(s0 + 0),
# CHECK: s0 + 1
(s0 + 1),
# CHECK: s0
(s0 * 1),
# CHECK: s0 * 2
(s0 * 2),
# CHECK: s0 * s0
(s0 * s0),
# CHECK: s0 * s1
(s0 * s1),
# CHECK: s0 * s0
(s0**2),
# CHECK: (s0 * s0) * s0
(s0**3),
# CHECK: ((((s0 * s0) * s0) * s0) * s0) * s0
((s0**2) ** 3),
# CHECK: ((((((s0 * s0) * s0) * s0) * s0) * s0) * s0) * s0
(s0 ** (2**3)),
# CHECK: s0 mod 10
(s0 % 10),
# CHECK: s0 - s1 * 2 + 5
(s0 + 5 - 2 * s1),
]
for expr in SYMPY_EXPRS:
print(sympy_expr_to_semi_affine_expr(expr, symbols_map))

View File

@@ -36,8 +36,13 @@ def test_scalar_typed_node():
x = x + 1.0
return x.shape[0]
# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %arg0, [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
# CHECK: torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,4],f32>, !torch.int -> !torch.int
m = fx.export_and_import(
Basic(),
torch.randn(3, 4),
dynamic_shapes={"x": {0: torch.export.Dim("b")}},
import_symbolic_shape_expressions=True,
)
print(m)