mirror of https://github.com/llvm/torch-mlir
Representing Symbolic Shape Expressions in Torch Dialect (#3372)
Torch Dialect with symbolic shape expressions: ```ll module { func.func @main(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> { %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int torch.bind_symbolic_shape %arg0, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> torch.bind_symbolic_shape %arg1, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> %3 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32> torch.bind_symbolic_shape %3, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> %4 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32> torch.bind_symbolic_shape %4, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> %5 = torch.prim.ListConstruct %3, %3, %4 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor> %int1 = torch.constant.int 1 %6 = torch.aten.cat %5, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32> torch.bind_symbolic_shape %6, [%0, %1, %2], #affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32> return %6 : !torch.vtensor<[?,?,3],f32> } } ``` For reference, this is the TorchDynamo exported program with symbolic shape expressions that the above Torch dialect program is imported from: ```py ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[s0, s1, 3]", y: "f32[s0, s3, 3]"): # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:31 in forward, code: a = torch.tanh(x) tanh: "f32[s0, s1, 3]" = torch.ops.aten.tanh.default(x); x = None # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:32 in forward, code: b = torch.sigmoid(y) sigmoid: "f32[s0, s3, 3]" = torch.ops.aten.sigmoid.default(y); y = None # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:33 in forward, code: return torch.cat((a, a, b), dim=1) cat: "f32[s0, 2*s1 + s3, 3]" = torch.ops.aten.cat.default([tanh, tanh, sigmoid], 1); tanh = sigmoid = None return (cat,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)]) Range constraints: {s0: ValueRanges(lower=5, upper=10, is_bool=False), s1: ValueRanges(lower=0, upper=100, is_bool=False), s3: ValueRanges(lower=0, upper=50, is_bool=False)} ``` Huge credit to @stellaraccident for the inputs that helped evaluate the various design options and arrive at the representation of choice. - [x] Op definitions for symbolic_int and bind_symbolic_shape ops - [x] fx_importer updates to import range constraints + create symbolic_int ops - [x] fx_importer changes for AffineMapAttr building + adding bind_symbolic_shape ops - [x] custom printer/parser for inlined AffineMap expressions in mlir assembly - [x] Dialect lit test - [x] fx_importer python lit tests - [ ] Cleanup pass to remove these ops (can add in a follow-on)pull/3430/head
parent
431d98b405
commit
d0a818a03e
|
@ -11,6 +11,7 @@
|
|||
#define TORCH_OPS
|
||||
|
||||
include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
|
||||
include "mlir/IR/BuiltinAttributes.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
|
@ -1337,4 +1338,67 @@ def Torch_DtypeCalculateYieldDtypesOp : Torch_Op<"dtype.calculate.yield.dtypes",
|
|||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Symbolic shape modeling ops for TorchDynamo frontend.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Torch_SymbolicIntOp : Torch_Op<"symbolic_int", [Pure]> {
|
||||
let summary = "Symbolic int representing a dynamic dimension";
|
||||
let description = [{
|
||||
The `torch.symbolic_int` operation captures a dynamic dimension on the
|
||||
global function arguments as exported by TorchDynamo (torch.export).
|
||||
It associates the shape symbols (i.e. "s0", "s1") with the
|
||||
global SSA values (i.e. `%0`, `%1`) that is then referenced
|
||||
to bind shapes on op results.
|
||||
|
||||
Additionally, the operation annotates `min_val` and `max_val` attributes
|
||||
denoting the range constraints for the dynamic dimension. This may be
|
||||
useful for modeling runtime shape guards, or compile-time optimizations
|
||||
based on the shape bounds (min, opt, max) on results of ops / regions.
|
||||
|
||||
Example:
|
||||
```
|
||||
%0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
|
||||
%1 = torch.symbolic_int "s1" {min_val = 2, max_val = 20} : !torch.int
|
||||
```
|
||||
}];
|
||||
let arguments = (ins
|
||||
StrAttr:$symbol_name,
|
||||
I64Attr:$min_val,
|
||||
I64Attr:$max_val
|
||||
);
|
||||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$symbol_name ` ` `{` `min_val` `=` $min_val `,` `max_val` `=` $max_val `}` attr-dict `:` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_BindSymbolicShapeOp : Torch_Op<"bind_symbolic_shape", []> {
|
||||
let summary = "Binds shape expressions to tensors using an affine map indexed by shape symbols";
|
||||
let description = [{
|
||||
The `torch.bind_symbolic_shape` operation binds shape expressions
|
||||
useful to compute the dynamic dimensions of a tensor. It takes a
|
||||
variadic of SSA symbols that map 1:1 to the local symbols declared
|
||||
in the affine map. The affine map contains a list of affine shape
|
||||
expressions for each dim where the terminals are from the declared
|
||||
symbols.
|
||||
|
||||
Example:
|
||||
```
|
||||
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
|
||||
torch.bind_symbolic_shape %out0, [%0, %1, %2], affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>
|
||||
```
|
||||
}];
|
||||
let arguments = (ins
|
||||
Torch_ValueTensorType:$operand,
|
||||
Variadic<Torch_IntType>:$shape_symbols,
|
||||
Builtin_AffineMapAttr:$shape_expressions
|
||||
);
|
||||
let results = (outs);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
#endif // TORCH_OPS
|
||||
|
|
|
@ -5034,3 +5034,65 @@ LogicalResult InitializeGlobalSlotsOp::verify() {
|
|||
return emitOpError("expected number of operands to match number of slots");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BindSymbolicShapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
//
|
||||
// torch.bind_symbolic_shape %6, [%0, %1, %2], affine_map<()[s0, s1, s2] ->
|
||||
// (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>
|
||||
//
|
||||
|
||||
ParseResult BindSymbolicShapeOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::UnresolvedOperand operand;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> shapeSymbols;
|
||||
AffineMapAttr shapeExpressions;
|
||||
Type operandType;
|
||||
|
||||
if (parser.parseOperand(operand) || parser.parseComma() ||
|
||||
parser.parseLSquare() || parser.parseOperandList(shapeSymbols) ||
|
||||
parser.parseRSquare() || parser.parseComma() ||
|
||||
parser.parseAttribute(shapeExpressions, "shape_expressions",
|
||||
result.attributes) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(operandType)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.resolveOperand(operand, operandType, result.operands) ||
|
||||
parser.resolveOperands(shapeSymbols,
|
||||
parser.getBuilder().getType<Torch::IntType>(),
|
||||
result.operands)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// Use a custom printer here to avoid the AffineMap from getting hoisted
|
||||
// when printed. This makes it so the AffineMap is printed inline with the op.
|
||||
void BindSymbolicShapeOp::print(OpAsmPrinter &p) {
|
||||
p << " " << getOperand() << ", [";
|
||||
llvm::interleaveComma(getShapeSymbols(), p);
|
||||
p << "], " << "affine_map<" << getShapeExpressions().getValue() << ">";
|
||||
p.printOptionalAttrDict((*this)->getAttrs(),
|
||||
/*elidedAttrs=*/{"shape_expressions"});
|
||||
p << " : " << getOperand().getType();
|
||||
}
|
||||
|
||||
LogicalResult BindSymbolicShapeOp::verify() {
|
||||
if (getShapeSymbols().empty())
|
||||
return emitOpError() << "requires non-empty shapeSymbols";
|
||||
|
||||
for (auto symbol : getShapeSymbols()) {
|
||||
Operation *definingOp = symbol.getDefiningOp();
|
||||
if (!isa<SymbolicIntOp>(definingOp)) {
|
||||
return emitOpError()
|
||||
<< "shape symbol must be produced by a SymbolicIntOp";
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -49,6 +49,9 @@ class FxImporterTestConfig(TestConfig):
|
|||
prog,
|
||||
output_type=self._output_type,
|
||||
func_name=artifact.__class__.__name__,
|
||||
# While the current e2e tests don't exercise symbolic shapes,
|
||||
# enabling this here ensures they don't regress either.
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
module = self._backend.compile(module)
|
||||
backend_module = self._backend.load(module)
|
||||
|
|
|
@ -14,6 +14,8 @@ except ImportError:
|
|||
import logging
|
||||
import operator
|
||||
import re
|
||||
import sympy
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from types import BuiltinMethodType, BuiltinFunctionType
|
||||
from typing import (
|
||||
|
@ -81,6 +83,14 @@ from torch.fx.node import (
|
|||
)
|
||||
|
||||
from ..ir import (
|
||||
AffineAddExpr,
|
||||
AffineConstantExpr,
|
||||
AffineExpr,
|
||||
AffineMap,
|
||||
AffineMapAttr,
|
||||
AffineModExpr,
|
||||
AffineMulExpr,
|
||||
AffineSymbolExpr,
|
||||
Attribute,
|
||||
Block,
|
||||
Context,
|
||||
|
@ -258,6 +268,71 @@ else:
|
|||
SYMBOLIC_TORCH_OPS = {key for key in SYMBOLIC_OP_TO_TORCH_OP}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RangeConstraint:
|
||||
min_val: int
|
||||
max_val: int
|
||||
|
||||
|
||||
def sympy_expr_to_semi_affine_expr(
|
||||
expr: sympy.Expr, symbols_map: Dict[str, AffineSymbolExpr]
|
||||
) -> AffineExpr:
|
||||
"""Translate sympy expressions to MLIR (semi-)affine expressions.
|
||||
|
||||
Recursively traverse the sympy expr AST and build the affine expr.
|
||||
This is not a perfect translation. Sympy expressions are much more
|
||||
expressive and not as constrained as affine (linear) expressions are.
|
||||
However, for the most part, we don't need to support all of sympy.
|
||||
PyTorch only uses a subset of sympy for capturing and expressing
|
||||
symbolic shapes, and among what's supported, we expect the semi-affine
|
||||
expressions (https://mlir.llvm.org/docs/Dialects/Affine/#semi-affine-maps)
|
||||
to be sufficient.
|
||||
"""
|
||||
if isinstance(expr, sympy.Symbol):
|
||||
return symbols_map[str(expr)]
|
||||
elif isinstance(expr, (int, sympy.Integer)):
|
||||
return AffineConstantExpr.get(expr)
|
||||
# This handles both add (`s0 + c`) and subtract (`s0 - c`).
|
||||
# The expression is `sympy.Add` in both cases but with args
|
||||
# (s0, c) in first case and (s0, -c) in the second case.
|
||||
elif isinstance(expr, sympy.Add):
|
||||
affine_expr = AffineConstantExpr.get(0)
|
||||
for arg in expr.args:
|
||||
affine_expr = AffineAddExpr.get(
|
||||
affine_expr, sympy_expr_to_semi_affine_expr(arg, symbols_map)
|
||||
)
|
||||
return affine_expr
|
||||
elif isinstance(expr, sympy.Mul):
|
||||
affine_expr = AffineConstantExpr.get(1)
|
||||
for arg in expr.args:
|
||||
affine_expr = AffineMulExpr.get(
|
||||
affine_expr, sympy_expr_to_semi_affine_expr(arg, symbols_map)
|
||||
)
|
||||
return affine_expr
|
||||
elif isinstance(expr, sympy.Pow):
|
||||
base, exp = expr.args
|
||||
# Only integer exponent is supported
|
||||
# So, s1 ** s0 isn't allowed.
|
||||
assert isinstance(exp, (int, sympy.Integer))
|
||||
assert exp > 0, "Only positive exponents supported in sympy.Pow"
|
||||
affine_expr = AffineConstantExpr.get(1)
|
||||
for _ in range(exp):
|
||||
affine_expr = AffineMulExpr.get(
|
||||
affine_expr, sympy_expr_to_semi_affine_expr(base, symbols_map)
|
||||
)
|
||||
return affine_expr
|
||||
elif isinstance(expr, sympy.Mod):
|
||||
dividend, divisor = expr.args
|
||||
return AffineModExpr.get(
|
||||
sympy_expr_to_semi_affine_expr(dividend, symbols_map),
|
||||
sympy_expr_to_semi_affine_expr(divisor, symbols_map),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Translation of sympy.Expr of type {type(expr)} not implemented yet."
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SparsityMeta:
|
||||
"""
|
||||
|
@ -478,6 +553,7 @@ class FxImporter:
|
|||
*,
|
||||
func_name: str = "main",
|
||||
func_visibility: Optional[str] = None,
|
||||
import_symbolic_shape_expressions: bool = False,
|
||||
) -> Operation:
|
||||
"""Imports an ExportedProgram according to our chosen canonical representation.
|
||||
|
||||
|
@ -527,6 +603,10 @@ class FxImporter:
|
|||
|
||||
sig = prog.graph_signature
|
||||
|
||||
# Populate symbolic guards for dynamic shapes (if any)
|
||||
if import_symbolic_shape_expressions:
|
||||
self._cc.set_symbolic_guards(prog)
|
||||
|
||||
# Invert the (producer, node_name) maps for mutated user inputs and mutated
|
||||
# buffers. This is because we hit-detect based on the input node name.
|
||||
mutated_user_inputs = {
|
||||
|
@ -682,7 +762,9 @@ class FxImporter:
|
|||
|
||||
# Import all nodes and return.
|
||||
node_importer.import_nodes(
|
||||
all_producer_nodes.values(), skip_placeholders_outputs=True
|
||||
all_producer_nodes.values(),
|
||||
skip_placeholders_outputs=True,
|
||||
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
|
||||
)
|
||||
node_importer.return_node_values(loc, user_outputs)
|
||||
self.symbol_table.insert(func_op)
|
||||
|
@ -694,6 +776,7 @@ class FxImporter:
|
|||
*,
|
||||
func_name: str = "main",
|
||||
func_visibility: Optional[str] = None,
|
||||
import_symbolic_shape_expressions: bool = False,
|
||||
) -> Operation:
|
||||
"""Imports a consolidated torch.export.ExportedProgram instance.
|
||||
|
||||
|
@ -728,6 +811,10 @@ class FxImporter:
|
|||
state_dict = prog.state_dict
|
||||
arg_replacements: Dict[str, Any] = {}
|
||||
|
||||
# Populate symbolic guards for dynamic shapes (if any)
|
||||
if import_symbolic_shape_expressions:
|
||||
self._cc.set_symbolic_guards(prog)
|
||||
|
||||
# If there is no "constants" attribute, consult the "state_dict". Otherwise, only look
|
||||
# at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
|
||||
if hasattr(prog, "constants"):
|
||||
|
@ -774,7 +861,10 @@ class FxImporter:
|
|||
g.erase_node(node)
|
||||
|
||||
return self.import_stateless_graph(
|
||||
g, func_name=func_name, func_visibility=func_visibility
|
||||
g,
|
||||
func_name=func_name,
|
||||
func_visibility=func_visibility,
|
||||
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
|
||||
)
|
||||
|
||||
def import_graph_module(self, gm: GraphModule) -> Operation:
|
||||
|
@ -791,6 +881,7 @@ class FxImporter:
|
|||
*,
|
||||
func_name: str = "main",
|
||||
func_visibility: Optional[str] = None,
|
||||
import_symbolic_shape_expressions: bool = False,
|
||||
) -> Operation:
|
||||
"""Low-level import of a functionalized, assumed stateless Graph as a func.
|
||||
|
||||
|
@ -815,7 +906,9 @@ class FxImporter:
|
|||
self._cc,
|
||||
entry_block,
|
||||
)
|
||||
node_importer.import_nodes(g.nodes)
|
||||
node_importer.import_nodes(
|
||||
g.nodes, import_symbolic_shape_expressions=import_symbolic_shape_expressions
|
||||
)
|
||||
self.symbol_table.insert(func)
|
||||
return func
|
||||
|
||||
|
@ -870,6 +963,7 @@ class ContextCache:
|
|||
"_c",
|
||||
"_dtype_to_type",
|
||||
"_tensor_metadata_cache",
|
||||
"_symbolic_guards",
|
||||
"_py_attr_tracker",
|
||||
# Types.
|
||||
"torch_bool_type",
|
||||
|
@ -888,6 +982,7 @@ class ContextCache:
|
|||
self._tensor_metadata_cache: Dict[
|
||||
Tuple[torch.Size, torch.dtype, Optional[SparsityMeta], bool], IrType
|
||||
] = {}
|
||||
self._symbolic_guards: Dict = {}
|
||||
self._py_attr_tracker = py_attr_tracker or RefTracker()
|
||||
|
||||
# Common types.
|
||||
|
@ -1037,6 +1132,52 @@ class ContextCache:
|
|||
return Location.file(filename, line, col=0, context=self._c)
|
||||
return Location.unknown(context=self._c)
|
||||
|
||||
def set_symbolic_guards(
|
||||
self, prog: torch.export.ExportedProgram
|
||||
) -> Dict[str, RangeConstraint]:
|
||||
|
||||
def _sympy_int_to_int(val: sympy.Expr, adjust_func: Callable):
|
||||
# Convert simple sympy Integers into concrete int
|
||||
if val == sympy.oo:
|
||||
return math.inf
|
||||
if val == -sympy.oo:
|
||||
return -math.inf
|
||||
if isinstance(val, sympy.Integer):
|
||||
return int(val)
|
||||
# TODO: Remove this adjustment when fractional ranges are removed
|
||||
return adjust_func(val)
|
||||
|
||||
contains_symbolic_ints = False
|
||||
for val in prog.range_constraints.values():
|
||||
if (
|
||||
isinstance(val.lower, sympy.Integer)
|
||||
and isinstance(val.upper, sympy.Integer)
|
||||
and not val.is_bool
|
||||
):
|
||||
contains_symbolic_ints = True
|
||||
break
|
||||
if contains_symbolic_ints:
|
||||
# Build a map from shape symbol name to `RangeConstraint` object
|
||||
# capturing `min_val`` and `max_val`` constraints for that
|
||||
# symbol. Translate sympy integers to regular integers.
|
||||
#
|
||||
# Example:
|
||||
# {
|
||||
# 's0': RangeConstraint(min_val=5, max_val=10),
|
||||
# 's1': RangeConstraint(min_val=0, max_val=100),
|
||||
# 's3': RangeConstraint(min_val=0, max_val=9223372036854775806),
|
||||
# }
|
||||
self._symbolic_guards = {
|
||||
str(k): RangeConstraint(
|
||||
_sympy_int_to_int(v.lower, math.ceil),
|
||||
_sympy_int_to_int(v.upper, math.floor),
|
||||
)
|
||||
for k, v in prog.range_constraints.items()
|
||||
}
|
||||
|
||||
def get_symbolic_guards(self) -> Dict[str, RangeConstraint]:
|
||||
return self._symbolic_guards
|
||||
|
||||
|
||||
class GraphNodeImporter:
|
||||
"""Imports graph nodes into an MLIR function.
|
||||
|
@ -1050,6 +1191,7 @@ class GraphNodeImporter:
|
|||
"_cc",
|
||||
"_on_node_produced",
|
||||
"_v",
|
||||
"_symbol_to_value",
|
||||
"_multi_result_nodes",
|
||||
"fx_importer",
|
||||
]
|
||||
|
@ -1068,6 +1210,8 @@ class GraphNodeImporter:
|
|||
# Map of (Node, result_index) to MLIR Value or a callback that lazily
|
||||
# constructs and returns a value.
|
||||
self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {}
|
||||
# Map of Shape Symbol to MLIR Value
|
||||
self._symbol_to_value: Dict[str, Value] = {}
|
||||
# Map of node name to hook that should be called when it is produced.
|
||||
self._on_node_produced: Dict[str, Callable[[Value], None]] = {}
|
||||
# Statically multi-result nodes which we have de-tupled are noted here.
|
||||
|
@ -1108,6 +1252,28 @@ class GraphNodeImporter:
|
|||
self._v[key] = value
|
||||
return value
|
||||
|
||||
def bind_symbol_value(
|
||||
self,
|
||||
shape_symbol: str,
|
||||
value: Value,
|
||||
):
|
||||
"""Binds a shape symbol to a global SSA value (and asserts if already bound)."""
|
||||
assert (
|
||||
shape_symbol not in self._symbol_to_value
|
||||
), f"Symbol already has a value: {shape_symbol}"
|
||||
self._symbol_to_value[shape_symbol] = value
|
||||
|
||||
def resolve_symbol_value(self, shape_symbol: str) -> Value:
|
||||
"""Resolves a shape symbol to a value."""
|
||||
try:
|
||||
binding = self._symbol_to_value[shape_symbol]
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
f"Shape symbol {shape_symbol} has not been bound to an MLIR value"
|
||||
)
|
||||
if isinstance(binding, Value):
|
||||
return binding
|
||||
|
||||
def import_mutable_to_vtensor(
|
||||
self, loc: Location, node: Node, mutable_value: Value, producer_node_name: str
|
||||
) -> Value:
|
||||
|
@ -1190,10 +1356,20 @@ class GraphNodeImporter:
|
|||
func_dialect.ReturnOp(operands, loc=loc)
|
||||
|
||||
def import_nodes(
|
||||
self, nodes: Iterable[Node], *, skip_placeholders_outputs: bool = False
|
||||
self,
|
||||
nodes: Iterable[Node],
|
||||
*,
|
||||
skip_placeholders_outputs: bool = False,
|
||||
import_symbolic_shape_expressions: bool = False,
|
||||
):
|
||||
with InsertionPoint(self._b):
|
||||
loc = Location.unknown()
|
||||
|
||||
# Import dynamic shape symbols and guards (if any)
|
||||
if import_symbolic_shape_expressions:
|
||||
symbolic_guards = self._cc.get_symbolic_guards()
|
||||
self._import_shape_symbols_with_guards(loc, symbolic_guards)
|
||||
|
||||
num_placeholders = 0
|
||||
for node in nodes:
|
||||
op = node.op
|
||||
|
@ -1253,6 +1429,8 @@ class GraphNodeImporter:
|
|||
operands = [self._import_argument(loc, arg) for arg in node.args[0]]
|
||||
func_dialect.ReturnOp(operands, loc=loc)
|
||||
|
||||
self._create_bind_symbolic_shape_ops(loc, node)
|
||||
|
||||
def _promote_symbolic_scalar_int_float(self, loc, graph, param):
|
||||
temp_target = torch.ops.aten.Float.Scalar
|
||||
temp_node = Node(
|
||||
|
@ -1516,6 +1694,69 @@ class GraphNodeImporter:
|
|||
for i, value in enumerate(operation.results):
|
||||
self.bind_node_value(node, value, i)
|
||||
|
||||
def _import_shape_symbols_with_guards(
|
||||
self, loc: Location, symbolic_guards: Dict[str, RangeConstraint]
|
||||
):
|
||||
for symbol, constraints in symbolic_guards.items():
|
||||
# Create torch.sym_int ops
|
||||
operation = Operation.create(
|
||||
name="torch.symbolic_int",
|
||||
attributes={
|
||||
"symbol_name": StringAttr.get(symbol),
|
||||
"min_val": self._cc.integer_attr(constraints.min_val, 64),
|
||||
"max_val": self._cc.integer_attr(constraints.max_val, 64),
|
||||
},
|
||||
results=[self._cc.torch_int_type],
|
||||
loc=loc,
|
||||
)
|
||||
self.bind_symbol_value(symbol, operation.result)
|
||||
|
||||
def _create_bind_symbolic_shape_ops(self, loc: Location, node: torch_fx.Node):
|
||||
node_val = node.meta.get("val")
|
||||
if (node_val is not None) and isinstance(node_val, TorchFakeTensor):
|
||||
# Only create bind ops if the shapes contain symbolic sizes.
|
||||
# Query the bool attribute `_has_symbolic_sizes_strides` on node.meta["val"].
|
||||
if node_val._has_symbolic_sizes_strides:
|
||||
# Read node metadata to obtain shape symbols and expressions
|
||||
symbols_set = set()
|
||||
shape_exprs = []
|
||||
for s in node_val.size():
|
||||
if isinstance(s, torch.SymInt):
|
||||
symbols_set.update(s.node.expr.free_symbols)
|
||||
shape_exprs.append(s.node.expr)
|
||||
else:
|
||||
assert isinstance(s, int)
|
||||
shape_exprs.append(s)
|
||||
|
||||
# Map from sympy shape symbols to local symbols in the affine map
|
||||
symbols_set = sorted(symbols_set, key=lambda x: x.name)
|
||||
symbols_map = {
|
||||
str(symbol): AffineSymbolExpr.get(i)
|
||||
for i, symbol in enumerate(symbols_set)
|
||||
}
|
||||
|
||||
# Convert symbolic shape expressions into affine expressions
|
||||
affine_exprs = [
|
||||
sympy_expr_to_semi_affine_expr(expr, symbols_map)
|
||||
for expr in shape_exprs
|
||||
]
|
||||
|
||||
affine_map = AffineMap.get(0, len(symbols_set), affine_exprs)
|
||||
|
||||
# Build operand list
|
||||
operand_list = []
|
||||
operand_list.append(self.resolve_node_value(node))
|
||||
for symbol in symbols_map.keys():
|
||||
operand_list.append(self.resolve_symbol_value(symbol))
|
||||
|
||||
# Create torch.bind_symbolic_shape ops
|
||||
Operation.create(
|
||||
name="torch.bind_symbolic_shape",
|
||||
attributes={"shape_expressions": AffineMapAttr.get(affine_map)},
|
||||
operands=operand_list,
|
||||
loc=loc,
|
||||
)
|
||||
|
||||
def _import_argument(
|
||||
self, loc: Location, arg: NodeArgument, expected_jit_type=None
|
||||
) -> Value:
|
||||
|
|
|
@ -54,6 +54,7 @@ def export_and_import(
|
|||
fx_importer: Optional[FxImporter] = None,
|
||||
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
||||
experimental_support_mutation: bool = False,
|
||||
import_symbolic_shape_expressions: bool = False,
|
||||
hooks: Optional[FxImporterHooks] = None,
|
||||
decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
|
||||
func_name: str = "main",
|
||||
|
@ -79,9 +80,17 @@ def export_and_import(
|
|||
if experimental_support_mutation:
|
||||
if torch.__version__ < "2.3.0.dev20240207":
|
||||
warnings.warn("Mutable program import only supported on PyTorch 2.3+")
|
||||
fx_importer.import_program(prog, func_name=func_name)
|
||||
fx_importer.import_program(
|
||||
prog,
|
||||
func_name=func_name,
|
||||
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
|
||||
)
|
||||
else:
|
||||
fx_importer.import_frozen_program(prog, func_name=func_name)
|
||||
fx_importer.import_frozen_program(
|
||||
prog,
|
||||
func_name=func_name,
|
||||
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
|
||||
)
|
||||
|
||||
return _module_lowering(
|
||||
enable_ir_printing, OutputType.get(output_type), fx_importer.module
|
||||
|
|
|
@ -3026,3 +3026,35 @@ func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!
|
|||
%1 = torch.copy.to_tensor %0 : !torch.tensor
|
||||
return %1 : !torch.tensor
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @torch.symbolic_int$canonicalize(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
// CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
|
||||
// CHECK-NOT: %[[S1:.*]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int
|
||||
// CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
// CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32>
|
||||
// CHECK: %[[V1:.*]] = torch.aten.slice.Tensor %[[ARG1]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
|
||||
// CHECK: torch.bind_symbolic_shape %[[V1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
// CHECK: %[[V2:.*]] = torch.aten.add.Tensor %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
|
||||
// CHECK: torch.bind_symbolic_shape %[[V2]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
// CHECK: return %[[V2]] : !torch.vtensor<[?],f32>
|
||||
func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
%0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
|
||||
%1 = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int
|
||||
torch.bind_symbolic_shape %arg0, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
torch.bind_symbolic_shape %arg1, [%0], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32>
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%int9223372036854775807 = torch.constant.int 9223372036854775807
|
||||
%int1_0 = torch.constant.int 1
|
||||
%2 = torch.aten.slice.Tensor %arg1, %int0, %int1, %int9223372036854775807, %int1_0 : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
|
||||
torch.bind_symbolic_shape %2, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
%int1_1 = torch.constant.int 1
|
||||
%3 = torch.aten.add.Tensor %arg0, %2, %int1_1 : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
|
||||
torch.bind_symbolic_shape %3, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
return %3 : !torch.vtensor<[?],f32>
|
||||
}
|
||||
|
|
|
@ -375,3 +375,22 @@ func.func @foo(%arg0: !torch.vtensor<[64,64],f32,#SV>) -> !torch.vtensor<[64,64]
|
|||
|
||||
// expected-error @+1 {{invalid sparsity encoding attribute}}
|
||||
func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345>
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
%0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
|
||||
// expected-error @+1 {{op requires non-empty shapeSymbols}}
|
||||
torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
return %arg0 : !torch.vtensor<[?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
// expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}}
|
||||
torch.bind_symbolic_shape %arg0, [%int0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
return %arg0 : !torch.vtensor<[?],f32>
|
||||
}
|
||||
|
|
|
@ -89,6 +89,11 @@ def test_import_frozen_exported_program_with_func_name():
|
|||
@run
|
||||
# CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes
|
||||
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32>
|
||||
# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
|
||||
# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,4],f32> -> !torch.vtensor<[?,4],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
|
||||
# CHECK: return %[[TANH]] : !torch.vtensor<[?,4],f32>
|
||||
def test_import_frozen_exported_program_with_dynamic_shapes():
|
||||
class Basic(nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -100,7 +105,11 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
|
|||
batch = Dim("batch")
|
||||
dynamic_shapes = {"x": {0: batch}}
|
||||
m = fx.export_and_import(
|
||||
Basic(), torch.randn(3, 4), dynamic_shapes=dynamic_shapes, func_name="test_net"
|
||||
Basic(),
|
||||
torch.randn(3, 4),
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
func_name="test_net",
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
||||
|
||||
|
@ -108,6 +117,12 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
|
|||
@run
|
||||
# CHECK-LABEL: test_broadcast_with_dynamic_shapes
|
||||
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32>
|
||||
# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
# CHECK: torch.aten.size.int
|
||||
# CHECK: torch.prim.ListConstruct
|
||||
# CHECK: %[[EXPAND:.*]] = torch.aten.expand
|
||||
# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],f32>
|
||||
def test_broadcast_with_dynamic_shapes():
|
||||
class Basic(nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -127,7 +142,12 @@ def test_broadcast_with_dynamic_shapes():
|
|||
}
|
||||
|
||||
m = fx.export_and_import(
|
||||
Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net"
|
||||
Basic(),
|
||||
x,
|
||||
y,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
func_name="test_net",
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
||||
|
||||
|
|
|
@ -0,0 +1,463 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
# This file contains tests of various op special forms that the fx_importer
|
||||
# handles.
|
||||
|
||||
import torch
|
||||
import torch.export
|
||||
import torch.nn as nn
|
||||
from torch.export import Dim
|
||||
|
||||
from torch_mlir import fx
|
||||
|
||||
|
||||
def run(f):
|
||||
print(f"{f.__name__}")
|
||||
print("-" * len(f.__name__))
|
||||
f()
|
||||
print()
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_tanh_sigmoid_cat
|
||||
# CHECK: func.func @main(
|
||||
# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
|
||||
# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
|
||||
# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
|
||||
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
|
||||
# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int
|
||||
# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int
|
||||
# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
|
||||
# CHECK: %[[TANH:.+]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
|
||||
# CHECK: %[[SIG:.+]] = torch.aten.sigmoid %[[ARG1]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[SIG]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
|
||||
# CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[TANH]], %[[TANH]], %[[SIG]], %[[ARG2]] : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>
|
||||
# CHECK: %[[CAT:.+]] = torch.aten.cat %[[LIST]], {{.*}} : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[CAT]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
|
||||
# CHECK: return %[[CAT]] : !torch.vtensor<[?,?,3],f32>
|
||||
def test_tanh_sigmoid_cat():
|
||||
class TanhSigmoidCat(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, y, z):
|
||||
a = torch.tanh(x)
|
||||
b = torch.sigmoid(y)
|
||||
return torch.cat((a, a, b, z), dim=1)
|
||||
|
||||
# Sample inputs
|
||||
x = torch.randn(5, 2, 3)
|
||||
y = torch.randn(5, 6, 3)
|
||||
z = torch.randn(5, 4, 3)
|
||||
|
||||
# Dynamic dim constraints
|
||||
dim_n = Dim("n", min=5, max=10)
|
||||
dim_x1 = Dim("x1", max=100)
|
||||
dim_y1 = Dim("y1", max=50)
|
||||
dim_z1 = Dim("z1")
|
||||
dynamic_shapes = {
|
||||
"x": {0: dim_n, 1: dim_x1},
|
||||
"y": {0: dim_n, 1: dim_y1},
|
||||
"z": {0: dim_n, 1: dim_z1},
|
||||
}
|
||||
|
||||
m = fx.export_and_import(
|
||||
TanhSigmoidCat(),
|
||||
x,
|
||||
y,
|
||||
z,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_symbolic_dim_differ_by_one
|
||||
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} {
|
||||
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
|
||||
# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+)
|
||||
# CHECK-DISABLED: %[[S1:.+]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32>
|
||||
# CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %arg1, {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[SLICE]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
# CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[ARG0]], %[[SLICE]], {{.*}} : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[ADD]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
# CHECK: return %[[ADD]] : !torch.vtensor<[?],f32>
|
||||
def test_symbolic_dim_differ_by_one():
|
||||
class SymbolicDimDifferByOne(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
return x + y[1:]
|
||||
|
||||
# Sample inputs
|
||||
x = torch.randn(5)
|
||||
y = torch.randn(6)
|
||||
|
||||
# Dynamic dim constraints
|
||||
dimx = Dim("dimx", min=3, max=6)
|
||||
dimy = dimx + 1
|
||||
dynamic_shapes = {
|
||||
"x": {0: dimx},
|
||||
"y": {0: dimy},
|
||||
}
|
||||
|
||||
m = fx.export_and_import(
|
||||
SymbolicDimDifferByOne(),
|
||||
x,
|
||||
y,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
experimental_support_mutation=True,
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_outer_with_squared_shape
|
||||
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
# CHECK: %[[VIEW1:.+]] = torch.aten.view %[[ARG0]], {{.*}} : !torch.vtensor<[?],f32>, !torch.list<int> -> !torch.vtensor<[?,1],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[VIEW1]], [%[[S0]]], affine_map<()[s0] -> (s0, 1)> : !torch.vtensor<[?,1],f32>
|
||||
# CHECK: %[[MUL:.+]] = torch.aten.mul.Tensor %[[VIEW1]], %[[ARG0]] : !torch.vtensor<[?,1],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[MUL]], [%[[S0]]], affine_map<()[s0] -> (s0, s0)> : !torch.vtensor<[?,?],f32>
|
||||
# CHECK: %[[VIEW2:.+]] = torch.aten.view %[[MUL]], {{.*}} : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[VIEW2]], [%[[S0]]], affine_map<()[s0] -> (s0 * s0)> : !torch.vtensor<[?],f32>
|
||||
# CHECK: return %[[VIEW2]] : !torch.vtensor<[?],f32>
|
||||
def test_outer_with_squared_shape():
|
||||
class OuterWithSquaredShape(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.outer(x, x).flatten()
|
||||
|
||||
# Sample inputs
|
||||
x = torch.rand(10)
|
||||
|
||||
# Dynamic dim constraints
|
||||
batch = Dim("batch")
|
||||
dynamic_shapes = {"x": {0: batch}}
|
||||
|
||||
m = fx.export_and_import(
|
||||
OuterWithSquaredShape(),
|
||||
x,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_slice_tensor_static_output
|
||||
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[2,1],f32> {
|
||||
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
|
||||
# CHECK: %[[SLICE1:.+]] = torch.aten.slice.Tensor %[[ARG0]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
|
||||
# CHECK: %[[SLICE2:.+]] = torch.aten.slice.Tensor %[[SLICE1]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32>
|
||||
# CHECK: return %[[SLICE2]] : !torch.vtensor<[2,1],f32>
|
||||
def test_slice_tensor_static_output():
|
||||
class SliceTensorStaticOutput(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x[0:2, :1]
|
||||
|
||||
# Sample inputs
|
||||
x = torch.randn(4, 3)
|
||||
|
||||
# Dynamic dim constraints
|
||||
batch = Dim("batch", min=3)
|
||||
dynamic_shapes = {"x": {0: batch}}
|
||||
|
||||
m = fx.export_and_import(
|
||||
SliceTensorStaticOutput(),
|
||||
x,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_slice_tensor_dynamic_output
|
||||
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 9223372036854775806} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
# CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %[[ARG0]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[SLICE]], [%[[S0]]], affine_map<()[s0] -> (s0 - 5)> : !torch.vtensor<[?],f32>
|
||||
# CHECK: return %[[SLICE]] : !torch.vtensor<[?],f32>
|
||||
def test_slice_tensor_dynamic_output():
|
||||
class SliceTensorDynamicOutput(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x[5:]
|
||||
|
||||
# Sample inputs
|
||||
x = torch.randn(10)
|
||||
|
||||
# Dynamic dim constraints
|
||||
dimx = Dim("dimx", min=5)
|
||||
dynamic_shapes = {"x": {0: dimx}}
|
||||
|
||||
m = fx.export_and_import(
|
||||
SliceTensorDynamicOutput(),
|
||||
x,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_div_tensor_mixed_ranks
|
||||
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[],f32>, %[[ARG1:.+]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[?,3],f32> {
|
||||
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
|
||||
# CHECK: %[[DIV:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,3],f32> -> !torch.vtensor<[?,3],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[DIV]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
|
||||
# CHECK: return %[[DIV]] : !torch.vtensor<[?,3],f32>
|
||||
def test_div_tensor_mixed_ranks():
|
||||
class DivTensorMixedRanks(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
div = torch.div(x, y)
|
||||
return div
|
||||
|
||||
# Sample inputs
|
||||
x = torch.tensor(10.0)
|
||||
y = torch.randn(2, 3)
|
||||
|
||||
# Dynamic dim constraints
|
||||
batch = Dim("batch")
|
||||
dynamic_shapes = {"x": None, "y": {0: batch}}
|
||||
|
||||
m = fx.export_and_import(
|
||||
DivTensorMixedRanks(),
|
||||
x,
|
||||
y,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_shape_div
|
||||
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,7],f32>) -> !torch.vtensor<[?,5],f32> {
|
||||
# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+)
|
||||
# CHECK-DISABLED: %[[S0:.+]] = torch.symbolic_int "5*s1" {min_val = 0, max_val = 5000} : !torch.int
|
||||
# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = 2, max_val = 1000} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]]], affine_map<()[s0] -> (s0 * 5, 7)> : !torch.vtensor<[?,7],f32>
|
||||
# CHECK: %[[VIEW:.+]] = torch.aten.view %[[ARG0]], {{.*}} : !torch.vtensor<[?,7],f32>, !torch.list<int> -> !torch.vtensor<[?,5],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[VIEW]], [%[[S1]]], affine_map<()[s0] -> (s0 * 7, 5)> : !torch.vtensor<[?,5],f32>
|
||||
# CHECK: return %[[VIEW]] : !torch.vtensor<[?,5],f32>
|
||||
def test_shape_div():
|
||||
class ShapeDiv(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.reshape(-1, 5)
|
||||
|
||||
# Sample inputs
|
||||
x = torch.rand(10, 7)
|
||||
|
||||
# Dynamic dim constraints
|
||||
batch = Dim("batch", max=1000) * 5
|
||||
dynamic_shapes = {"x": {0: batch}}
|
||||
|
||||
m = fx.export_and_import(
|
||||
ShapeDiv(),
|
||||
x,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_broadcast_unit_dim_to_static_with_unchanged_dim_dynamic
|
||||
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,?],f32>) -> !torch.vtensor<[3,?],f32> {
|
||||
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (1, s0)> : !torch.vtensor<[1,?],f32>
|
||||
# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,?],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,?],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (3, s0)> : !torch.vtensor<[3,?],f32>
|
||||
# CHECK: return %[[EXPAND]] : !torch.vtensor<[3,?],f32>
|
||||
def test_broadcast_unit_dim_to_static_with_unchanged_dim_dynamic():
|
||||
class BroadcastUnitDimToStaticWithUnchangedDimDynamic(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.broadcast_to(x, (3, -1))
|
||||
|
||||
# Sample inputs
|
||||
x = torch.randn(1, 2)
|
||||
|
||||
# Dynamic dim constraints
|
||||
dim_1 = Dim("dim_1")
|
||||
dynamic_shapes = {"x": {1: dim_1}}
|
||||
|
||||
m = fx.export_and_import(
|
||||
BroadcastUnitDimToStaticWithUnchangedDimDynamic(),
|
||||
x,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_static
|
||||
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32> {
|
||||
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,2],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],f32>
|
||||
# CHECK: return %3 : !torch.vtensor<[?,2],f32>
|
||||
def test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_static():
|
||||
class BroadcastUnitDimToDynamicWithUnchangedDimStatic(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return torch.broadcast_to(x, (y.shape[0], -1))
|
||||
|
||||
# Sample inputs
|
||||
x = torch.randn(1, 2)
|
||||
y = torch.randn(10)
|
||||
|
||||
# Dynamic dim constraints
|
||||
dim_0 = Dim("dim_0")
|
||||
dynamic_shapes = {"x": {}, "y": {0: dim_0}}
|
||||
|
||||
m = fx.export_and_import(
|
||||
BroadcastUnitDimToDynamicWithUnchangedDimStatic(),
|
||||
x,
|
||||
y,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic
|
||||
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,?],f32>, %[[ARG1:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
|
||||
# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (1, s0)> : !torch.vtensor<[1,?],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S1]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||
# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,?],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s1, s0)> : !torch.vtensor<[?,?],f32>
|
||||
# CHECK: return %[[EXPAND]] : !torch.vtensor<[?,?],f32>
|
||||
def test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic():
|
||||
class BroadcastUnitDimToDynamicWithUnchangedDimDynamic(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return torch.broadcast_to(x, (y.shape[0], -1))
|
||||
|
||||
# Sample inputs
|
||||
x = torch.randn(1, 2)
|
||||
y = torch.randn(10)
|
||||
|
||||
# Dynamic dim constraints
|
||||
dim_0 = Dim("dim_0")
|
||||
dim_1 = Dim("dim_1")
|
||||
dynamic_shapes = {"x": {1: dim_1}, "y": {0: dim_0}}
|
||||
|
||||
m = fx.export_and_import(
|
||||
BroadcastUnitDimToDynamicWithUnchangedDimDynamic(),
|
||||
x,
|
||||
y,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_broadcast_unit_dim_to_dynamic_with_rank_increase
|
||||
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:.+]]: !torch.vtensor<[?,3,2],f32>) -> !torch.vtensor<[?,3,2],f32> {
|
||||
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0, 3, 2)> : !torch.vtensor<[?,3,2],f32>
|
||||
# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,3,2],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 3, 2)> : !torch.vtensor<[?,3,2],f32>
|
||||
# CHECK: return %[[EXPAND]] : !torch.vtensor<[?,3,2],f32>
|
||||
def test_broadcast_unit_dim_to_dynamic_with_rank_increase():
|
||||
class BroadcastUnitDimToDynamicWithRankIncrease(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return torch.broadcast_to(x, y.size())
|
||||
|
||||
# Sample inputs
|
||||
x = torch.randn(1, 2)
|
||||
y = torch.randn(4, 3, 2)
|
||||
|
||||
# Dynamic dim constraints
|
||||
dim_0 = Dim("dim_0")
|
||||
dynamic_shapes = {"x": {}, "y": {0: dim_0}}
|
||||
|
||||
m = fx.export_and_import(
|
||||
BroadcastUnitDimToDynamicWithRankIncrease(),
|
||||
x,
|
||||
y,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_gather_elements
|
||||
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.+]]: !torch.vtensor<[2,3],si64>) -> !torch.vtensor<[2,3],f32> {
|
||||
# CHECK: %[[S0]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
|
||||
# CHECK: %[[GATHER:.+]] = torch.aten.gather %[[ARG0]], {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.bool -> !torch.vtensor<[2,3],f32>
|
||||
# CHECK: return %[[GATHER]] : !torch.vtensor<[2,3],f32>
|
||||
def test_gather_elements():
|
||||
class GatherElements(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return torch.gather(x, 0, y)
|
||||
|
||||
# Sample inputs
|
||||
x = torch.randn(4, 3)
|
||||
y = torch.tensor([[0, 0, 0], [1, 1, 1]])
|
||||
|
||||
# Dynamic dim constraints
|
||||
batch = Dim("batch", min=3)
|
||||
dynamic_shapes = {"x": {0: batch}, "y": {}}
|
||||
|
||||
m = fx.export_and_import(
|
||||
GatherElements(),
|
||||
x,
|
||||
y,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
|
@ -0,0 +1,69 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
# This file contains tests checking translating sympy expressions to (semi-)affine expressions.
|
||||
|
||||
from sympy import Symbol
|
||||
from torch_mlir.extras.fx_importer import sympy_expr_to_semi_affine_expr
|
||||
|
||||
from torch_mlir.ir import (
|
||||
AffineSymbolExpr,
|
||||
Context,
|
||||
)
|
||||
|
||||
|
||||
def run(f):
|
||||
print(f"{f.__name__}")
|
||||
print("-" * len(f.__name__))
|
||||
f()
|
||||
print()
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_sympy_to_semi_affine_expr_translation
|
||||
def test_sympy_to_semi_affine_expr_translation():
|
||||
with Context():
|
||||
s0 = Symbol("s0", positive=True, integer=True)
|
||||
s1 = Symbol("s1", positive=True, integer=True)
|
||||
|
||||
symbols_set = sorted({s0, s1}, key=lambda x: x.name)
|
||||
symbols_map = {
|
||||
str(symbol): AffineSymbolExpr.get(i) for i, symbol in enumerate(symbols_set)
|
||||
}
|
||||
|
||||
SYMPY_EXPRS = [
|
||||
# CHECK: 10
|
||||
(10),
|
||||
# CHECK: s0
|
||||
(s0),
|
||||
# CHECK: s0
|
||||
(s0 + 0),
|
||||
# CHECK: s0 + 1
|
||||
(s0 + 1),
|
||||
# CHECK: s0
|
||||
(s0 * 1),
|
||||
# CHECK: s0 * 2
|
||||
(s0 * 2),
|
||||
# CHECK: s0 * s0
|
||||
(s0 * s0),
|
||||
# CHECK: s0 * s1
|
||||
(s0 * s1),
|
||||
# CHECK: s0 * s0
|
||||
(s0**2),
|
||||
# CHECK: (s0 * s0) * s0
|
||||
(s0**3),
|
||||
# CHECK: ((((s0 * s0) * s0) * s0) * s0) * s0
|
||||
((s0**2) ** 3),
|
||||
# CHECK: ((((((s0 * s0) * s0) * s0) * s0) * s0) * s0) * s0
|
||||
(s0 ** (2**3)),
|
||||
# CHECK: s0 mod 10
|
||||
(s0 % 10),
|
||||
# CHECK: s0 - s1 * 2 + 5
|
||||
(s0 + 5 - 2 * s1),
|
||||
]
|
||||
|
||||
for expr in SYMPY_EXPRS:
|
||||
print(sympy_expr_to_semi_affine_expr(expr, symbols_map))
|
|
@ -36,8 +36,13 @@ def test_scalar_typed_node():
|
|||
x = x + 1.0
|
||||
return x.shape[0]
|
||||
|
||||
# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %arg0, [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
|
||||
# CHECK: torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,4],f32>, !torch.int -> !torch.int
|
||||
m = fx.export_and_import(
|
||||
Basic(), torch.randn(3, 4), dynamic_shapes={"x": {0: torch.export.Dim("b")}}
|
||||
Basic(),
|
||||
torch.randn(3, 4),
|
||||
dynamic_shapes={"x": {0: torch.export.Dim("b")}},
|
||||
import_symbolic_shape_expressions=True,
|
||||
)
|
||||
print(m)
|
||||
|
|
Loading…
Reference in New Issue