mirror of https://github.com/llvm/torch-mlir
d0a818a03e
Torch Dialect with symbolic shape expressions:

```ll
module {
  func.func @main(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
    %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int
    %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int
    torch.bind_symbolic_shape %arg0, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
    torch.bind_symbolic_shape %arg1, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
    %3 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
    torch.bind_symbolic_shape %3, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
    %4 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
    torch.bind_symbolic_shape %4, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
    %5 = torch.prim.ListConstruct %3, %3, %4 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>
    %int1 = torch.constant.int 1
    %6 = torch.aten.cat %5, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>
    torch.bind_symbolic_shape %6, [%0, %1, %2], #affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>
    return %6 : !torch.vtensor<[?,?,3],f32>
  }
}
```

For reference, this is the TorchDynamo exported program with symbolic shape expressions that the above Torch dialect program is imported from:

```py
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, s1, 3]", y: "f32[s0, s3, 3]"):
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:31 in forward, code: a = torch.tanh(x)
            tanh: "f32[s0, s1, 3]" = torch.ops.aten.tanh.default(x); x = None

            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:32 in forward, code: b = torch.sigmoid(y)
            sigmoid: "f32[s0, s3, 3]" = torch.ops.aten.sigmoid.default(y); y = None

            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:33 in forward, code: return torch.cat((a, a, b), dim=1)
            cat: "f32[s0, 2*s1 + s3, 3]" = torch.ops.aten.cat.default([tanh, tanh, sigmoid], 1); tanh = sigmoid = None
            return (cat,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)])
Range constraints: {s0: ValueRanges(lower=5, upper=10, is_bool=False), s1: ValueRanges(lower=0, upper=100, is_bool=False), s3: ValueRanges(lower=0, upper=50, is_bool=False)}
```

Huge credit to @stellaraccident for the inputs that helped evaluate the various design options and arrive at the representation of choice.

- [x] Op definitions for `symbolic_int` and `bind_symbolic_shape` ops
- [x] fx_importer updates to import range constraints + create `symbolic_int` ops
- [x] fx_importer changes for `AffineMapAttr` building + adding `bind_symbolic_shape` ops
- [x] Custom printer/parser for inlined `AffineMap` expressions in MLIR assembly
- [x] Dialect lit test
- [x] fx_importer Python lit tests
- [ ] Cleanup pass to remove these ops (can add in a follow-on)
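To make the mapping from the exported program to the Torch dialect concrete, here is a minimal pure-Python sketch. It shows how range constraints and symbolic shape expressions could be turned into the `torch.symbolic_int` and `torch.bind_symbolic_shape` lines in the example above. It is not the actual fx_importer code. The helpers `emit_symbolic_ints`, `affine_map_for`, `emit_bind` and `evaluate` are hypothetical. Shapes are also simplified: each dimension is either a static int or a linear `{symbol: coefficient}` dict, instead of the sympy expressions that torch.export actually carries.

```python
from typing import Dict, List, Tuple, Union

# Hypothetical simplification: a dimension is a static size (int) or a linear
# combination of symbols {symbol_name: coefficient}, with no constant term.
Dim = Union[int, Dict[str, int]]


def emit_symbolic_ints(
    ranges: Dict[str, Tuple[int, int]],
) -> Tuple[List[str], Dict[str, str]]:
    """Emit one torch.symbolic_int line per range constraint.

    Returns the lines plus a map from symbol name to its SSA value name.
    """
    lines: List[str] = []
    ssa: Dict[str, str] = {}
    for i, (name, (lo, hi)) in enumerate(ranges.items()):
        ssa[name] = f"%{i}"
        lines.append(
            f'%{i} = torch.symbolic_int "{name}" '
            f"{{min_val = {lo}, max_val = {hi}}} : !torch.int"
        )
    return lines, ssa


def affine_map_for(shape: List[Dim]) -> Tuple[List[str], str]:
    """Return (symbols used, in order of first use) and an inline affine map."""
    used: List[str] = []
    for d in shape:
        if isinstance(d, dict):
            for name in d:
                if name not in used:
                    used.append(name)
    # Affine-map symbols are positional: the k-th used symbol becomes s{k}.
    pos = {name: f"s{k}" for k, name in enumerate(used)}
    results: List[str] = []
    for d in shape:
        if isinstance(d, int):
            results.append(str(d))
        else:
            terms = [pos[n] if c == 1 else f"{pos[n]} * {c}" for n, c in d.items()]
            results.append(" + ".join(terms))
    syms = ", ".join(pos[n] for n in used)
    return used, f"#affine_map<()[{syms}] -> ({', '.join(results)})>"


def emit_bind(value: str, shape: List[Dim], ssa: Dict[str, str], ty: str) -> str:
    """Emit a torch.bind_symbolic_shape line tying `value` to its shape symbols."""
    used, amap = affine_map_for(shape)
    operands = ", ".join(ssa[n] for n in used)
    return f"torch.bind_symbolic_shape {value}, [{operands}], {amap} : {ty}"


def evaluate(shape: List[Dim], env: Dict[str, int]) -> Tuple[int, ...]:
    """Concretize a symbolic shape for given symbol values."""
    return tuple(
        d if isinstance(d, int) else sum(c * env[n] for n, c in d.items())
        for d in shape
    )


# Usage, mirroring the range constraints and the `cat` result shown above.
ranges = {"s0": (5, 10), "s1": (0, 100), "s3": (0, 50)}
lines, ssa = emit_symbolic_ints(ranges)
ty = "!torch.vtensor<[?,?,3],f32>"
cat_shape: List[Dim] = [{"s0": 1}, {"s1": 2, "s3": 1}, 3]  # f32[s0, 2*s1 + s3, 3]
lines.append(emit_bind("%6", cat_shape, ssa, ty))
print("\n".join(lines))
```

Because affine-map symbols are positional, each tensor's map numbers its own symbols from `s0`. The operand list is what ties `s2` in the `cat` map back to `%2`, the value for the exported symbol `s3`.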
- auto_functionalized.py
- lit.local.cfg
- mutation_import.py
- special_forms_test.py
- types_test.py