diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index 65f514c2e..035632878 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -11,6 +11,7 @@ #define TORCH_OPS include "torch-mlir/Dialect/Torch/IR/TorchTypes.td" +include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" @@ -1337,4 +1338,67 @@ def Torch_DtypeCalculateYieldDtypesOp : Torch_Op<"dtype.calculate.yield.dtypes", let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// Symbolic shape modeling ops for TorchDynamo frontend. +//===----------------------------------------------------------------------===// + +def Torch_SymbolicIntOp : Torch_Op<"symbolic_int", [Pure]> { + let summary = "Symbolic int representing a dynamic dimension"; + let description = [{ + The `torch.symbolic_int` operation captures a dynamic dimension on the + global function arguments as exported by TorchDynamo (torch.export). + It associates the shape symbols (i.e. "s0", "s1") with the + global SSA values (i.e. `%0`, `%1`) that is then referenced + to bind shapes on op results. + + Additionally, the operation annotates `min_val` and `max_val` attributes + denoting the range constraints for the dynamic dimension. This may be + useful for modeling runtime shape guards, or compile-time optimizations + based on the shape bounds (min, opt, max) on results of ops / regions. + + Example: + ``` + %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int + %1 = torch.symbolic_int "s1" {min_val = 2, max_val = 20} : !torch.int + ``` + }]; + let arguments = (ins + StrAttr:$symbol_name, + I64Attr:$min_val, + I64Attr:$max_val + ); + let results = (outs + Torch_IntType:$result + ); + let assemblyFormat = [{ + $symbol_name ` ` `{` `min_val` `=` $min_val `,` `max_val` `=` $max_val `}` attr-dict `:` type($result) + }]; +} + +def Torch_BindSymbolicShapeOp : Torch_Op<"bind_symbolic_shape", []> { + let summary = "Binds shape expressions to tensors using an affine map indexed by shape symbols"; + let description = [{ + The `torch.bind_symbolic_shape` operation binds shape expressions + useful to compute the dynamic dimensions of a tensor. It takes a + variadic of SSA symbols that map 1:1 to the local symbols declared + in the affine map. The affine map contains a list of affine shape + expressions for each dim where the terminals are from the declared + symbols. + + Example: + ``` + torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> + torch.bind_symbolic_shape %out0, [%0, %1, %2], affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32> + ``` + }]; + let arguments = (ins + Torch_ValueTensorType:$operand, + Variadic:$shape_symbols, + Builtin_AffineMapAttr:$shape_expressions + ); + let results = (outs); + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + #endif // TORCH_OPS diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 5e0f0ab1e..994722f3e 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -5034,3 +5034,65 @@ LogicalResult InitializeGlobalSlotsOp::verify() { return emitOpError("expected number of operands to match number of slots"); return success(); } + +//===----------------------------------------------------------------------===// +// BindSymbolicShapeOp +//===----------------------------------------------------------------------===// + +// +// torch.bind_symbolic_shape %6, [%0, %1, %2], affine_map<()[s0, s1, s2] -> +// (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32> +// + +ParseResult BindSymbolicShapeOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand operand; + SmallVector shapeSymbols; + AffineMapAttr shapeExpressions; + Type operandType; + + if (parser.parseOperand(operand) || parser.parseComma() || + parser.parseLSquare() || parser.parseOperandList(shapeSymbols) || + parser.parseRSquare() || parser.parseComma() || + parser.parseAttribute(shapeExpressions, "shape_expressions", + result.attributes) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(operandType)) { + return failure(); + } + + if (parser.resolveOperand(operand, operandType, result.operands) || + parser.resolveOperands(shapeSymbols, + parser.getBuilder().getType(), + result.operands)) { + return failure(); + } + + return success(); +} + +// Use a custom printer here to avoid the AffineMap from getting hoisted +// when printed. This makes it so the AffineMap is printed inline with the op. +void BindSymbolicShapeOp::print(OpAsmPrinter &p) { + p << " " << getOperand() << ", ["; + llvm::interleaveComma(getShapeSymbols(), p); + p << "], " << "affine_map<" << getShapeExpressions().getValue() << ">"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"shape_expressions"}); + p << " : " << getOperand().getType(); +} + +LogicalResult BindSymbolicShapeOp::verify() { + if (getShapeSymbols().empty()) + return emitOpError() << "requires non-empty shapeSymbols"; + + for (auto symbol : getShapeSymbols()) { + Operation *definingOp = symbol.getDefiningOp(); + if (!isa(definingOp)) { + return emitOpError() + << "shape symbol must be produced by a SymbolicIntOp"; + } + } + + return success(); +} diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py index 4cda217a1..11a6ef6ff 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py @@ -49,6 +49,9 @@ class FxImporterTestConfig(TestConfig): prog, output_type=self._output_type, func_name=artifact.__class__.__name__, + # While the current e2e tests don't exercise symbolic shapes, + # enabling this here ensures they don't regress either. + import_symbolic_shape_expressions=True, ) module = self._backend.compile(module) backend_module = self._backend.load(module) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index f328bc5d0..9dcb3c285 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -14,6 +14,8 @@ except ImportError: import logging import operator import re +import sympy +import math from dataclasses import dataclass from types import BuiltinMethodType, BuiltinFunctionType from typing import ( @@ -81,6 +83,14 @@ from torch.fx.node import ( ) from ..ir import ( + AffineAddExpr, + AffineConstantExpr, + AffineExpr, + AffineMap, + AffineMapAttr, + AffineModExpr, + AffineMulExpr, + AffineSymbolExpr, Attribute, Block, Context, @@ -258,6 +268,71 @@ else: SYMBOLIC_TORCH_OPS = {key for key in SYMBOLIC_OP_TO_TORCH_OP} +@dataclass +class RangeConstraint: + min_val: int + max_val: int + + +def sympy_expr_to_semi_affine_expr( + expr: sympy.Expr, symbols_map: Dict[str, AffineSymbolExpr] +) -> AffineExpr: + """Translate sympy expressions to MLIR (semi-)affine expressions. + + Recursively traverse the sympy expr AST and build the affine expr. + This is not a perfect translation. Sympy expressions are much more + expressive and not as constrained as affine (linear) expressions are. + However, for the most part, we don't need to support all of sympy. + PyTorch only uses a subset of sympy for capturing and expressing + symbolic shapes, and among what's supported, we expect the semi-affine + expressions (https://mlir.llvm.org/docs/Dialects/Affine/#semi-affine-maps) + to be sufficient. + """ + if isinstance(expr, sympy.Symbol): + return symbols_map[str(expr)] + elif isinstance(expr, (int, sympy.Integer)): + return AffineConstantExpr.get(expr) + # This handles both add (`s0 + c`) and subtract (`s0 - c`). + # The expression is `sympy.Add` in both cases but with args + # (s0, c) in first case and (s0, -c) in the second case. + elif isinstance(expr, sympy.Add): + affine_expr = AffineConstantExpr.get(0) + for arg in expr.args: + affine_expr = AffineAddExpr.get( + affine_expr, sympy_expr_to_semi_affine_expr(arg, symbols_map) + ) + return affine_expr + elif isinstance(expr, sympy.Mul): + affine_expr = AffineConstantExpr.get(1) + for arg in expr.args: + affine_expr = AffineMulExpr.get( + affine_expr, sympy_expr_to_semi_affine_expr(arg, symbols_map) + ) + return affine_expr + elif isinstance(expr, sympy.Pow): + base, exp = expr.args + # Only integer exponent is supported + # So, s1 ** s0 isn't allowed. + assert isinstance(exp, (int, sympy.Integer)) + assert exp > 0, "Only positive exponents supported in sympy.Pow" + affine_expr = AffineConstantExpr.get(1) + for _ in range(exp): + affine_expr = AffineMulExpr.get( + affine_expr, sympy_expr_to_semi_affine_expr(base, symbols_map) + ) + return affine_expr + elif isinstance(expr, sympy.Mod): + dividend, divisor = expr.args + return AffineModExpr.get( + sympy_expr_to_semi_affine_expr(dividend, symbols_map), + sympy_expr_to_semi_affine_expr(divisor, symbols_map), + ) + else: + raise NotImplementedError( + f"Translation of sympy.Expr of type {type(expr)} not implemented yet." + ) + + @dataclass(frozen=True) class SparsityMeta: """ @@ -478,6 +553,7 @@ class FxImporter: *, func_name: str = "main", func_visibility: Optional[str] = None, + import_symbolic_shape_expressions: bool = False, ) -> Operation: """Imports an ExportedProgram according to our chosen canonical representation. @@ -527,6 +603,10 @@ class FxImporter: sig = prog.graph_signature + # Populate symbolic guards for dynamic shapes (if any) + if import_symbolic_shape_expressions: + self._cc.set_symbolic_guards(prog) + # Invert the (producer, node_name) maps for mutated user inputs and mutated # buffers. This is because we hit-detect based on the input node name. mutated_user_inputs = { @@ -682,7 +762,9 @@ class FxImporter: # Import all nodes and return. node_importer.import_nodes( - all_producer_nodes.values(), skip_placeholders_outputs=True + all_producer_nodes.values(), + skip_placeholders_outputs=True, + import_symbolic_shape_expressions=import_symbolic_shape_expressions, ) node_importer.return_node_values(loc, user_outputs) self.symbol_table.insert(func_op) @@ -694,6 +776,7 @@ class FxImporter: *, func_name: str = "main", func_visibility: Optional[str] = None, + import_symbolic_shape_expressions: bool = False, ) -> Operation: """Imports a consolidated torch.export.ExportedProgram instance. @@ -728,6 +811,10 @@ class FxImporter: state_dict = prog.state_dict arg_replacements: Dict[str, Any] = {} + # Populate symbolic guards for dynamic shapes (if any) + if import_symbolic_shape_expressions: + self._cc.set_symbolic_guards(prog) + # If there is no "constants" attribute, consult the "state_dict". Otherwise, only look # at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969 if hasattr(prog, "constants"): @@ -774,7 +861,10 @@ class FxImporter: g.erase_node(node) return self.import_stateless_graph( - g, func_name=func_name, func_visibility=func_visibility + g, + func_name=func_name, + func_visibility=func_visibility, + import_symbolic_shape_expressions=import_symbolic_shape_expressions, ) def import_graph_module(self, gm: GraphModule) -> Operation: @@ -791,6 +881,7 @@ class FxImporter: *, func_name: str = "main", func_visibility: Optional[str] = None, + import_symbolic_shape_expressions: bool = False, ) -> Operation: """Low-level import of a functionalized, assumed stateless Graph as a func. @@ -815,7 +906,9 @@ class FxImporter: self._cc, entry_block, ) - node_importer.import_nodes(g.nodes) + node_importer.import_nodes( + g.nodes, import_symbolic_shape_expressions=import_symbolic_shape_expressions + ) self.symbol_table.insert(func) return func @@ -870,6 +963,7 @@ class ContextCache: "_c", "_dtype_to_type", "_tensor_metadata_cache", + "_symbolic_guards", "_py_attr_tracker", # Types. "torch_bool_type", @@ -888,6 +982,7 @@ class ContextCache: self._tensor_metadata_cache: Dict[ Tuple[torch.Size, torch.dtype, Optional[SparsityMeta], bool], IrType ] = {} + self._symbolic_guards: Dict = {} self._py_attr_tracker = py_attr_tracker or RefTracker() # Common types. @@ -1037,6 +1132,52 @@ class ContextCache: return Location.file(filename, line, col=0, context=self._c) return Location.unknown(context=self._c) + def set_symbolic_guards( + self, prog: torch.export.ExportedProgram + ) -> Dict[str, RangeConstraint]: + + def _sympy_int_to_int(val: sympy.Expr, adjust_func: Callable): + # Convert simple sympy Integers into concrete int + if val == sympy.oo: + return math.inf + if val == -sympy.oo: + return -math.inf + if isinstance(val, sympy.Integer): + return int(val) + # TODO: Remove this adjustment when fractional ranges are removed + return adjust_func(val) + + contains_symbolic_ints = False + for val in prog.range_constraints.values(): + if ( + isinstance(val.lower, sympy.Integer) + and isinstance(val.upper, sympy.Integer) + and not val.is_bool + ): + contains_symbolic_ints = True + break + if contains_symbolic_ints: + # Build a map from shape symbol name to `RangeConstraint` object + # capturing `min_val`` and `max_val`` constraints for that + # symbol. Translate sympy integers to regular integers. + # + # Example: + # { + # 's0': RangeConstraint(min_val=5, max_val=10), + # 's1': RangeConstraint(min_val=0, max_val=100), + # 's3': RangeConstraint(min_val=0, max_val=9223372036854775806), + # } + self._symbolic_guards = { + str(k): RangeConstraint( + _sympy_int_to_int(v.lower, math.ceil), + _sympy_int_to_int(v.upper, math.floor), + ) + for k, v in prog.range_constraints.items() + } + + def get_symbolic_guards(self) -> Dict[str, RangeConstraint]: + return self._symbolic_guards + class GraphNodeImporter: """Imports graph nodes into an MLIR function. @@ -1050,6 +1191,7 @@ class GraphNodeImporter: "_cc", "_on_node_produced", "_v", + "_symbol_to_value", "_multi_result_nodes", "fx_importer", ] @@ -1068,6 +1210,8 @@ class GraphNodeImporter: # Map of (Node, result_index) to MLIR Value or a callback that lazily # constructs and returns a value. self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {} + # Map of Shape Symbol to MLIR Value + self._symbol_to_value: Dict[str, Value] = {} # Map of node name to hook that should be called when it is produced. self._on_node_produced: Dict[str, Callable[[Value], None]] = {} # Statically multi-result nodes which we have de-tupled are noted here. @@ -1108,6 +1252,28 @@ class GraphNodeImporter: self._v[key] = value return value + def bind_symbol_value( + self, + shape_symbol: str, + value: Value, + ): + """Binds a shape symbol to a global SSA value (and asserts if already bound).""" + assert ( + shape_symbol not in self._symbol_to_value + ), f"Symbol already has a value: {shape_symbol}" + self._symbol_to_value[shape_symbol] = value + + def resolve_symbol_value(self, shape_symbol: str) -> Value: + """Resolves a shape symbol to a value.""" + try: + binding = self._symbol_to_value[shape_symbol] + except KeyError: + raise KeyError( + f"Shape symbol {shape_symbol} has not been bound to an MLIR value" + ) + if isinstance(binding, Value): + return binding + def import_mutable_to_vtensor( self, loc: Location, node: Node, mutable_value: Value, producer_node_name: str ) -> Value: @@ -1190,10 +1356,20 @@ class GraphNodeImporter: func_dialect.ReturnOp(operands, loc=loc) def import_nodes( - self, nodes: Iterable[Node], *, skip_placeholders_outputs: bool = False + self, + nodes: Iterable[Node], + *, + skip_placeholders_outputs: bool = False, + import_symbolic_shape_expressions: bool = False, ): with InsertionPoint(self._b): loc = Location.unknown() + + # Import dynamic shape symbols and guards (if any) + if import_symbolic_shape_expressions: + symbolic_guards = self._cc.get_symbolic_guards() + self._import_shape_symbols_with_guards(loc, symbolic_guards) + num_placeholders = 0 for node in nodes: op = node.op @@ -1253,6 +1429,8 @@ class GraphNodeImporter: operands = [self._import_argument(loc, arg) for arg in node.args[0]] func_dialect.ReturnOp(operands, loc=loc) + self._create_bind_symbolic_shape_ops(loc, node) + def _promote_symbolic_scalar_int_float(self, loc, graph, param): temp_target = torch.ops.aten.Float.Scalar temp_node = Node( @@ -1516,6 +1694,69 @@ class GraphNodeImporter: for i, value in enumerate(operation.results): self.bind_node_value(node, value, i) + def _import_shape_symbols_with_guards( + self, loc: Location, symbolic_guards: Dict[str, RangeConstraint] + ): + for symbol, constraints in symbolic_guards.items(): + # Create torch.sym_int ops + operation = Operation.create( + name="torch.symbolic_int", + attributes={ + "symbol_name": StringAttr.get(symbol), + "min_val": self._cc.integer_attr(constraints.min_val, 64), + "max_val": self._cc.integer_attr(constraints.max_val, 64), + }, + results=[self._cc.torch_int_type], + loc=loc, + ) + self.bind_symbol_value(symbol, operation.result) + + def _create_bind_symbolic_shape_ops(self, loc: Location, node: torch_fx.Node): + node_val = node.meta.get("val") + if (node_val is not None) and isinstance(node_val, TorchFakeTensor): + # Only create bind ops if the shapes contain symbolic sizes. + # Query the bool attribute `_has_symbolic_sizes_strides` on node.meta["val"]. + if node_val._has_symbolic_sizes_strides: + # Read node metadata to obtain shape symbols and expressions + symbols_set = set() + shape_exprs = [] + for s in node_val.size(): + if isinstance(s, torch.SymInt): + symbols_set.update(s.node.expr.free_symbols) + shape_exprs.append(s.node.expr) + else: + assert isinstance(s, int) + shape_exprs.append(s) + + # Map from sympy shape symbols to local symbols in the affine map + symbols_set = sorted(symbols_set, key=lambda x: x.name) + symbols_map = { + str(symbol): AffineSymbolExpr.get(i) + for i, symbol in enumerate(symbols_set) + } + + # Convert symbolic shape expressions into affine expressions + affine_exprs = [ + sympy_expr_to_semi_affine_expr(expr, symbols_map) + for expr in shape_exprs + ] + + affine_map = AffineMap.get(0, len(symbols_set), affine_exprs) + + # Build operand list + operand_list = [] + operand_list.append(self.resolve_node_value(node)) + for symbol in symbols_map.keys(): + operand_list.append(self.resolve_symbol_value(symbol)) + + # Create torch.bind_symbolic_shape ops + Operation.create( + name="torch.bind_symbolic_shape", + attributes={"shape_expressions": AffineMapAttr.get(affine_map)}, + operands=operand_list, + loc=loc, + ) + def _import_argument( self, loc: Location, arg: NodeArgument, expected_jit_type=None ) -> Value: diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index b8765b659..5cd7d2d6e 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -54,6 +54,7 @@ def export_and_import( fx_importer: Optional[FxImporter] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, experimental_support_mutation: bool = False, + import_symbolic_shape_expressions: bool = False, hooks: Optional[FxImporterHooks] = None, decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None, func_name: str = "main", @@ -79,9 +80,17 @@ def export_and_import( if experimental_support_mutation: if torch.__version__ < "2.3.0.dev20240207": warnings.warn("Mutable program import only supported on PyTorch 2.3+") - fx_importer.import_program(prog, func_name=func_name) + fx_importer.import_program( + prog, + func_name=func_name, + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ) else: - fx_importer.import_frozen_program(prog, func_name=func_name) + fx_importer.import_frozen_program( + prog, + func_name=func_name, + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ) return _module_lowering( enable_ir_printing, OutputType.get(output_type), fx_importer.module diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 180b6aac5..250f11cf6 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -3026,3 +3026,35 @@ func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (! %1 = torch.copy.to_tensor %0 : !torch.tensor return %1 : !torch.tensor } + + +// ----- + +// CHECK-LABEL: @torch.symbolic_int$canonicalize( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { +// CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int +// CHECK-NOT: %[[S1:.*]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int +// CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +// CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32> +// CHECK: %[[V1:.*]] = torch.aten.slice.Tensor %[[ARG1]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32> +// CHECK: torch.bind_symbolic_shape %[[V1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +// CHECK: %[[V2:.*]] = torch.aten.add.Tensor %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32> +// CHECK: torch.bind_symbolic_shape %[[V2]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +// CHECK: return %[[V2]] : !torch.vtensor<[?],f32> +func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { + %0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int + %1 = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int + torch.bind_symbolic_shape %arg0, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> + torch.bind_symbolic_shape %arg1, [%0], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32> + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int9223372036854775807 = torch.constant.int 9223372036854775807 + %int1_0 = torch.constant.int 1 + %2 = torch.aten.slice.Tensor %arg1, %int0, %int1, %int9223372036854775807, %int1_0 : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32> + torch.bind_symbolic_shape %2, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> + %int1_1 = torch.constant.int 1 + %3 = torch.aten.add.Tensor %arg0, %2, %int1_1 : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32> + torch.bind_symbolic_shape %3, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> + return %3 : !torch.vtensor<[?],f32> +} diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index 63aa1e375..5b732788f 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -375,3 +375,22 @@ func.func @foo(%arg0: !torch.vtensor<[64,64],f32,#SV>) -> !torch.vtensor<[64,64] // expected-error @+1 {{invalid sparsity encoding attribute}} func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345> + + +// ----- + +func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { + %0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int + // expected-error @+1 {{op requires non-empty shapeSymbols}} + torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> + return %arg0 : !torch.vtensor<[?],f32> +} + +// ----- + +func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { + %int0 = torch.constant.int 0 + // expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}} + torch.bind_symbolic_shape %arg0, [%int0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> + return %arg0 : !torch.vtensor<[?],f32> +} diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index fde318630..fbc8fdff3 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -89,6 +89,11 @@ def test_import_frozen_exported_program_with_func_name(): @run # CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes # CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32> +# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32> +# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,4],f32> -> !torch.vtensor<[?,4],f32> +# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32> +# CHECK: return %[[TANH]] : !torch.vtensor<[?,4],f32> def test_import_frozen_exported_program_with_dynamic_shapes(): class Basic(nn.Module): def __init__(self): @@ -100,7 +105,11 @@ def test_import_frozen_exported_program_with_dynamic_shapes(): batch = Dim("batch") dynamic_shapes = {"x": {0: batch}} m = fx.export_and_import( - Basic(), torch.randn(3, 4), dynamic_shapes=dynamic_shapes, func_name="test_net" + Basic(), + torch.randn(3, 4), + dynamic_shapes=dynamic_shapes, + func_name="test_net", + import_symbolic_shape_expressions=True, ) print(m) @@ -108,6 +117,12 @@ def test_import_frozen_exported_program_with_dynamic_shapes(): @run # CHECK-LABEL: test_broadcast_with_dynamic_shapes # CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32> +# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: torch.aten.size.int +# CHECK: torch.prim.ListConstruct +# CHECK: %[[EXPAND:.*]] = torch.aten.expand +# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],f32> def test_broadcast_with_dynamic_shapes(): class Basic(nn.Module): def __init__(self): @@ -127,7 +142,12 @@ def test_broadcast_with_dynamic_shapes(): } m = fx.export_and_import( - Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net" + Basic(), + x, + y, + dynamic_shapes=dynamic_shapes, + func_name="test_net", + import_symbolic_shape_expressions=True, ) print(m) diff --git a/test/python/fx_importer/symbolic_shape_expr_test.py b/test/python/fx_importer/symbolic_shape_expr_test.py new file mode 100644 index 000000000..3215e0f82 --- /dev/null +++ b/test/python/fx_importer/symbolic_shape_expr_test.py @@ -0,0 +1,463 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s +# This file contains tests of various op special forms that the fx_importer +# handles. + +import torch +import torch.export +import torch.nn as nn +from torch.export import Dim + +from torch_mlir import fx + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_tanh_sigmoid_cat +# CHECK: func.func @main( +# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, +# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, +# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int +# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int +# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int +# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: %[[TANH:.+]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: %[[SIG:.+]] = torch.aten.sigmoid %[[ARG1]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[SIG]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[TANH]], %[[TANH]], %[[SIG]], %[[ARG2]] : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list +# CHECK: %[[CAT:.+]] = torch.aten.cat %[[LIST]], {{.*}} : !torch.list, !torch.int -> !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[CAT]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: return %[[CAT]] : !torch.vtensor<[?,?,3],f32> +def test_tanh_sigmoid_cat(): + class TanhSigmoidCat(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y, z): + a = torch.tanh(x) + b = torch.sigmoid(y) + return torch.cat((a, a, b, z), dim=1) + + # Sample inputs + x = torch.randn(5, 2, 3) + y = torch.randn(5, 6, 3) + z = torch.randn(5, 4, 3) + + # Dynamic dim constraints + dim_n = Dim("n", min=5, max=10) + dim_x1 = Dim("x1", max=100) + dim_y1 = Dim("y1", max=50) + dim_z1 = Dim("z1") + dynamic_shapes = { + "x": {0: dim_n, 1: dim_x1}, + "y": {0: dim_n, 1: dim_y1}, + "z": {0: dim_n, 1: dim_z1}, + } + + m = fx.export_and_import( + TanhSigmoidCat(), + x, + y, + z, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_symbolic_dim_differ_by_one +# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int +# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+) +# CHECK-DISABLED: %[[S1:.+]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32> +# CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %arg1, {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[SLICE]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[ARG0]], %[[SLICE]], {{.*}} : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[ADD]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: return %[[ADD]] : !torch.vtensor<[?],f32> +def test_symbolic_dim_differ_by_one(): + class SymbolicDimDifferByOne(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x + y[1:] + + # Sample inputs + x = torch.randn(5) + y = torch.randn(6) + + # Dynamic dim constraints + dimx = Dim("dimx", min=3, max=6) + dimy = dimx + 1 + dynamic_shapes = { + "x": {0: dimx}, + "y": {0: dimy}, + } + + m = fx.export_and_import( + SymbolicDimDifferByOne(), + x, + y, + dynamic_shapes=dynamic_shapes, + experimental_support_mutation=True, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_outer_with_squared_shape +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: %[[VIEW1:.+]] = torch.aten.view %[[ARG0]], {{.*}} : !torch.vtensor<[?],f32>, !torch.list -> !torch.vtensor<[?,1],f32> +# CHECK: torch.bind_symbolic_shape %[[VIEW1]], [%[[S0]]], affine_map<()[s0] -> (s0, 1)> : !torch.vtensor<[?,1],f32> +# CHECK: %[[MUL:.+]] = torch.aten.mul.Tensor %[[VIEW1]], %[[ARG0]] : !torch.vtensor<[?,1],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32> +# CHECK: torch.bind_symbolic_shape %[[MUL]], [%[[S0]]], affine_map<()[s0] -> (s0, s0)> : !torch.vtensor<[?,?],f32> +# CHECK: %[[VIEW2:.+]] = torch.aten.view %[[MUL]], {{.*}} : !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[VIEW2]], [%[[S0]]], affine_map<()[s0] -> (s0 * s0)> : !torch.vtensor<[?],f32> +# CHECK: return %[[VIEW2]] : !torch.vtensor<[?],f32> +def test_outer_with_squared_shape(): + class OuterWithSquaredShape(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.outer(x, x).flatten() + + # Sample inputs + x = torch.rand(10) + + # Dynamic dim constraints + batch = Dim("batch") + dynamic_shapes = {"x": {0: batch}} + + m = fx.export_and_import( + OuterWithSquaredShape(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_slice_tensor_static_output +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[2,1],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: %[[SLICE1:.+]] = torch.aten.slice.Tensor %[[ARG0]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> +# CHECK: %[[SLICE2:.+]] = torch.aten.slice.Tensor %[[SLICE1]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32> +# CHECK: return %[[SLICE2]] : !torch.vtensor<[2,1],f32> +def test_slice_tensor_static_output(): + class SliceTensorStaticOutput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x[0:2, :1] + + # Sample inputs + x = torch.randn(4, 3) + + # Dynamic dim constraints + batch = Dim("batch", min=3) + dynamic_shapes = {"x": {0: batch}} + + m = fx.export_and_import( + SliceTensorStaticOutput(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_slice_tensor_dynamic_output +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 9223372036854775806} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %[[ARG0]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[SLICE]], [%[[S0]]], affine_map<()[s0] -> (s0 - 5)> : !torch.vtensor<[?],f32> +# CHECK: return %[[SLICE]] : !torch.vtensor<[?],f32> +def test_slice_tensor_dynamic_output(): + class SliceTensorDynamicOutput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[5:] + + # Sample inputs + x = torch.randn(10) + + # Dynamic dim constraints + dimx = Dim("dimx", min=5) + dynamic_shapes = {"x": {0: dimx}} + + m = fx.export_and_import( + SliceTensorDynamicOutput(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_div_tensor_mixed_ranks +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[],f32>, %[[ARG1:.+]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[?,3],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: %[[DIV:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,3],f32> -> !torch.vtensor<[?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[DIV]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: return %[[DIV]] : !torch.vtensor<[?,3],f32> +def test_div_tensor_mixed_ranks(): + class DivTensorMixedRanks(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + div = torch.div(x, y) + return div + + # Sample inputs + x = torch.tensor(10.0) + y = torch.randn(2, 3) + + # Dynamic dim constraints + batch = Dim("batch") + dynamic_shapes = {"x": None, "y": {0: batch}} + + m = fx.export_and_import( + DivTensorMixedRanks(), + x, + y, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_shape_div +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,7],f32>) -> !torch.vtensor<[?,5],f32> { +# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+) +# CHECK-DISABLED: %[[S0:.+]] = torch.symbolic_int "5*s1" {min_val = 0, max_val = 5000} : !torch.int +# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = 2, max_val = 1000} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]]], affine_map<()[s0] -> (s0 * 5, 7)> : !torch.vtensor<[?,7],f32> +# CHECK: %[[VIEW:.+]] = torch.aten.view %[[ARG0]], {{.*}} : !torch.vtensor<[?,7],f32>, !torch.list -> !torch.vtensor<[?,5],f32> +# CHECK: torch.bind_symbolic_shape %[[VIEW]], [%[[S1]]], affine_map<()[s0] -> (s0 * 7, 5)> : !torch.vtensor<[?,5],f32> +# CHECK: return %[[VIEW]] : !torch.vtensor<[?,5],f32> +def test_shape_div(): + class ShapeDiv(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.reshape(-1, 5) + + # Sample inputs + x = torch.rand(10, 7) + + # Dynamic dim constraints + batch = Dim("batch", max=1000) * 5 + dynamic_shapes = {"x": {0: batch}} + + m = fx.export_and_import( + ShapeDiv(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_broadcast_unit_dim_to_static_with_unchanged_dim_dynamic +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,?],f32>) -> !torch.vtensor<[3,?],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (1, s0)> : !torch.vtensor<[1,?],f32> +# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,?],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,?],f32> +# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (3, s0)> : !torch.vtensor<[3,?],f32> +# CHECK: return %[[EXPAND]] : !torch.vtensor<[3,?],f32> +def test_broadcast_unit_dim_to_static_with_unchanged_dim_dynamic(): + class BroadcastUnitDimToStaticWithUnchangedDimDynamic(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.broadcast_to(x, (3, -1)) + + # Sample inputs + x = torch.randn(1, 2) + + # Dynamic dim constraints + dim_1 = Dim("dim_1") + dynamic_shapes = {"x": {1: dim_1}} + + m = fx.export_and_import( + BroadcastUnitDimToStaticWithUnchangedDimDynamic(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_static +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[?,2],f32> +# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],f32> +# CHECK: return %3 : !torch.vtensor<[?,2],f32> +def test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_static(): + class BroadcastUnitDimToDynamicWithUnchangedDimStatic(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.broadcast_to(x, (y.shape[0], -1)) + + # Sample inputs + x = torch.randn(1, 2) + y = torch.randn(10) + + # Dynamic dim constraints + dim_0 = Dim("dim_0") + dynamic_shapes = {"x": {}, "y": {0: dim_0}} + + m = fx.export_and_import( + BroadcastUnitDimToDynamicWithUnchangedDimStatic(), + x, + y, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,?],f32>, %[[ARG1:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (1, s0)> : !torch.vtensor<[1,?],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S1]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,?],f32>, !torch.list, !torch.bool -> !torch.vtensor<[?,?],f32> +# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s1, s0)> : !torch.vtensor<[?,?],f32> +# CHECK: return %[[EXPAND]] : !torch.vtensor<[?,?],f32> +def test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic(): + class BroadcastUnitDimToDynamicWithUnchangedDimDynamic(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.broadcast_to(x, (y.shape[0], -1)) + + # Sample inputs + x = torch.randn(1, 2) + y = torch.randn(10) + + # Dynamic dim constraints + dim_0 = Dim("dim_0") + dim_1 = Dim("dim_1") + dynamic_shapes = {"x": {1: dim_1}, "y": {0: dim_0}} + + m = fx.export_and_import( + BroadcastUnitDimToDynamicWithUnchangedDimDynamic(), + x, + y, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_broadcast_unit_dim_to_dynamic_with_rank_increase +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:.+]]: !torch.vtensor<[?,3,2],f32>) -> !torch.vtensor<[?,3,2],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0, 3, 2)> : !torch.vtensor<[?,3,2],f32> +# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[?,3,2],f32> +# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 3, 2)> : !torch.vtensor<[?,3,2],f32> +# CHECK: return %[[EXPAND]] : !torch.vtensor<[?,3,2],f32> +def test_broadcast_unit_dim_to_dynamic_with_rank_increase(): + class BroadcastUnitDimToDynamicWithRankIncrease(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.broadcast_to(x, y.size()) + + # Sample inputs + x = torch.randn(1, 2) + y = torch.randn(4, 3, 2) + + # Dynamic dim constraints + dim_0 = Dim("dim_0") + dynamic_shapes = {"x": {}, "y": {0: dim_0}} + + m = fx.export_and_import( + BroadcastUnitDimToDynamicWithRankIncrease(), + x, + y, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_gather_elements +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.+]]: !torch.vtensor<[2,3],si64>) -> !torch.vtensor<[2,3],f32> { +# CHECK: %[[S0]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: %[[GATHER:.+]] = torch.aten.gather %[[ARG0]], {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.bool -> !torch.vtensor<[2,3],f32> +# CHECK: return %[[GATHER]] : !torch.vtensor<[2,3],f32> +def test_gather_elements(): + class GatherElements(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.gather(x, 0, y) + + # Sample inputs + x = torch.randn(4, 3) + y = torch.tensor([[0, 0, 0], [1, 1, 1]]) + + # Dynamic dim constraints + batch = Dim("batch", min=3) + dynamic_shapes = {"x": {0: batch}, "y": {}} + + m = fx.export_and_import( + GatherElements(), + x, + y, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) diff --git a/test/python/fx_importer/sympy_to_affine_expr_test.py b/test/python/fx_importer/sympy_to_affine_expr_test.py new file mode 100644 index 000000000..0c366040d --- /dev/null +++ b/test/python/fx_importer/sympy_to_affine_expr_test.py @@ -0,0 +1,69 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s +# This file contains tests checking translating sympy expressions to (semi-)affine expressions. + +from sympy import Symbol +from torch_mlir.extras.fx_importer import sympy_expr_to_semi_affine_expr + +from torch_mlir.ir import ( + AffineSymbolExpr, + Context, +) + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_sympy_to_semi_affine_expr_translation +def test_sympy_to_semi_affine_expr_translation(): + with Context(): + s0 = Symbol("s0", positive=True, integer=True) + s1 = Symbol("s1", positive=True, integer=True) + + symbols_set = sorted({s0, s1}, key=lambda x: x.name) + symbols_map = { + str(symbol): AffineSymbolExpr.get(i) for i, symbol in enumerate(symbols_set) + } + + SYMPY_EXPRS = [ + # CHECK: 10 + (10), + # CHECK: s0 + (s0), + # CHECK: s0 + (s0 + 0), + # CHECK: s0 + 1 + (s0 + 1), + # CHECK: s0 + (s0 * 1), + # CHECK: s0 * 2 + (s0 * 2), + # CHECK: s0 * s0 + (s0 * s0), + # CHECK: s0 * s1 + (s0 * s1), + # CHECK: s0 * s0 + (s0**2), + # CHECK: (s0 * s0) * s0 + (s0**3), + # CHECK: ((((s0 * s0) * s0) * s0) * s0) * s0 + ((s0**2) ** 3), + # CHECK: ((((((s0 * s0) * s0) * s0) * s0) * s0) * s0) * s0 + (s0 ** (2**3)), + # CHECK: s0 mod 10 + (s0 % 10), + # CHECK: s0 - s1 * 2 + 5 + (s0 + 5 - 2 * s1), + ] + + for expr in SYMPY_EXPRS: + print(sympy_expr_to_semi_affine_expr(expr, symbols_map)) diff --git a/test/python/fx_importer/v2.3/types_test.py b/test/python/fx_importer/v2.3/types_test.py index 19dee8b7b..eccea125c 100644 --- a/test/python/fx_importer/v2.3/types_test.py +++ b/test/python/fx_importer/v2.3/types_test.py @@ -36,8 +36,13 @@ def test_scalar_typed_node(): x = x + 1.0 return x.shape[0] + # CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int + # CHECK: torch.bind_symbolic_shape %arg0, [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32> # CHECK: torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,4],f32>, !torch.int -> !torch.int m = fx.export_and_import( - Basic(), torch.randn(3, 4), dynamic_shapes={"x": {0: torch.export.Dim("b")}} + Basic(), + torch.randn(3, 4), + dynamic_shapes={"x": {0: torch.export.Dim("b")}}, + import_symbolic_shape_expressions=True, ) print(m)