torch-mlir/test/Dialect/Torch
Sambhav Jain d0a818a03e
Representing Symbolic Shape Expressions in Torch Dialect (#3372)
Torch Dialect with symbolic shape expressions:
```mlir
module {
  func.func @main(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
    %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int
    %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int

    torch.bind_symbolic_shape %arg0, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
    torch.bind_symbolic_shape %arg1, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>

    %3 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
    torch.bind_symbolic_shape %3, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>

    %4 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
    torch.bind_symbolic_shape %4, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>

    %5 = torch.prim.ListConstruct %3, %3, %4 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>
    %int1 = torch.constant.int 1
    %6 = torch.aten.cat %5, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>
    torch.bind_symbolic_shape %6, [%0, %1, %2], #affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>

    return %6 : !torch.vtensor<[?,?,3],f32>
  }
}
```
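One way to read the bindings: once `s0`, `s1` and `s3` take concrete values inside their declared ranges, every bound tensor's shape follows from its affine map. The pure-Python sketch below models that reading. `SymbolicInt`, `BINDINGS` and `eval_shapes` are hypothetical illustrations, not torch-mlir APIs, and the ops themselves perform no such runtime evaluation.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass(frozen=True)
class SymbolicInt:
    """Hypothetical mirror of `torch.symbolic_int "name" {min_val, max_val}`."""

    name: str
    min_val: int
    max_val: int

    def check(self, value: int) -> int:
        # Reject values outside the inclusive [min_val, max_val] range.
        if not self.min_val <= value <= self.max_val:
            raise ValueError(
                f"{self.name}={value} outside [{self.min_val}, {self.max_val}]"
            )
        return value


# The three symbols declared by %0, %1 and %2 in the module above.
S0 = SymbolicInt("s0", 5, 10)
S1 = SymbolicInt("s1", 0, 100)
S3 = SymbolicInt("s3", 0, 50)

# Each bind_symbolic_shape op as (symbol operands, affine map). The map's
# parameters are positional, so in the %6 binding the map's `s2` is S3.
Binding = Tuple[List[SymbolicInt], Callable[..., Tuple[int, ...]]]
BINDINGS: Dict[str, Binding] = {
    "%arg0": ([S0, S1], lambda s0, s1: (s0, s1, 3)),
    "%arg1": ([S0, S3], lambda s0, s1: (s0, s1, 3)),
    "%3": ([S0, S1], lambda s0, s1: (s0, s1, 3)),
    "%4": ([S0, S3], lambda s0, s1: (s0, s1, 3)),
    "%6": ([S0, S1, S3], lambda s0, s1, s2: (s0, s1 * 2 + s2, 3)),
}


def eval_shapes(values: Dict[str, int]) -> Dict[str, Tuple[int, ...]]:
    """Evaluate every bound shape for concrete, range-checked symbol values."""
    shapes = {}
    for value_name, (operands, affine_map) in BINDINGS.items():
        args = [sym.check(values[sym.name]) for sym in operands]
        shapes[value_name] = affine_map(*args)
    return shapes


print(eval_shapes({"s0": 6, "s1": 4, "s3": 7})["%6"])  # (6, 15, 3)
```

With `s1 = 4` and `s3 = 7`, the `cat` result's middle dimension is `4 * 2 + 7 = 15`. Passing `s0 = 4` raises `ValueError` because it falls below `min_val = 5`.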

For reference, this is the TorchDynamo exported program, with symbolic
shape expressions, from which the Torch dialect program at the top was imported:
```py
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, s1, 3]", y: "f32[s0, s3, 3]"):
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:31 in forward, code: a = torch.tanh(x)
            tanh: "f32[s0, s1, 3]" = torch.ops.aten.tanh.default(x);  x = None

            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:32 in forward, code: b = torch.sigmoid(y)
            sigmoid: "f32[s0, s3, 3]" = torch.ops.aten.sigmoid.default(y);  y = None

            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:33 in forward, code: return torch.cat((a, a, b), dim=1)
            cat: "f32[s0, 2*s1 + s3, 3]" = torch.ops.aten.cat.default([tanh, tanh, sigmoid], 1);  tanh = sigmoid = None
            return (cat,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)])
Range constraints: {s0: ValueRanges(lower=5, upper=10, is_bool=False), s1: ValueRanges(lower=0, upper=100, is_bool=False), s3: ValueRanges(lower=0, upper=50, is_bool=False)}
```
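The `Range constraints` line corresponds one-to-one with the `torch.symbolic_int` ops at the top of the MLIR module. Below is a minimal sketch of that translation as plain string formatting. It assumes the constraints have already been pulled into a `{name: (lower, upper)}` dict with finite bounds. `emit_symbolic_int_ops` is a hypothetical helper, not the fx_importer's actual code.

```python
from typing import Dict, List, Tuple


def emit_symbolic_int_ops(constraints: Dict[str, Tuple[int, int]]) -> List[str]:
    """Render one `torch.symbolic_int` line per constrained symbol.

    SSA result names are numbered in dict order (%0, %1, ...), which matches how
    the module above numbers s0, s1 and s3. Unbounded ranges are not handled.
    """
    lines = []
    for index, (name, (lower, upper)) in enumerate(constraints.items()):
        lines.append(
            f'%{index} = torch.symbolic_int "{name}" '
            f"{{min_val = {lower}, max_val = {upper}}} : !torch.int"
        )
    return lines


# Bounds copied from the `Range constraints` line of the exported program.
RANGES = {"s0": (5, 10), "s1": (0, 100), "s3": (0, 50)}
for line in emit_symbolic_int_ops(RANGES):
    print(line)
```

The FX symbol names (including the skipped `s2`) carry over unchanged as the op's string attribute. Only the SSA result numbering is dense.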

Huge credit to @stellaraccident for the inputs that helped evaluate the
various design options and arrive at the representation of choice.


- [x] Op definitions for symbolic_int and bind_symbolic_shape ops
- [x] fx_importer updates to import range constraints + create symbolic_int ops
- [x] fx_importer changes for AffineMapAttr building + adding bind_symbolic_shape ops
- [x] Custom printer/parser for inlined AffineMap expressions in MLIR assembly
- [x] Dialect lit test
- [x] fx_importer Python lit tests
- [ ] Cleanup pass to remove these ops (can add in a follow-on)
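The AffineMapAttr-building step has to turn an FX shape such as `f32[s0, 2*s1 + s3, 3]` into a list of symbol operands plus a map whose symbols are renumbered positionally. The following sketch shows that renumbering. It assumes a hypothetical encoding of each dimension as either a static int or a `({symbol: coefficient}, constant)` pair, and it orders symbols by first use. Both the encoding and the ordering are assumptions, not necessarily what the fx_importer does. It prints only the map text inside the attribute.

```python
from typing import Dict, List, Sequence, Tuple, Union

# Hypothetical encoding of one FX dimension: a static size, or a linear
# combination of symbols plus a constant, e.g. 2*s1 + s3 -> ({"s1": 2, "s3": 1}, 0).
LinearExpr = Tuple[Dict[str, int], int]
DimExpr = Union[int, LinearExpr]


def collect_symbols(shape: Sequence[DimExpr]) -> List[str]:
    """FX symbols in first-use order; they become the map's s0, s1, ... positions."""
    seen: List[str] = []
    for dim in shape:
        if isinstance(dim, int):
            continue
        for sym in dim[0]:
            if sym not in seen:
                seen.append(sym)
    return seen


def render_dim(dim: DimExpr, position: Dict[str, int]) -> str:
    """Print one result expression; assumes non-negative coefficients and constants."""
    if isinstance(dim, int):
        return str(dim)
    coeffs, const = dim
    terms = []
    for sym, coeff in coeffs.items():
        term = f"s{position[sym]}"
        terms.append(term if coeff == 1 else f"{term} * {coeff}")
    if const or not terms:
        terms.append(str(const))
    return " + ".join(terms)


def build_map_text(shape: Sequence[DimExpr]) -> Tuple[List[str], str]:
    """Return (symbol operands, map text) for one hypothetical bind_symbolic_shape op."""
    symbols = collect_symbols(shape)
    position = {sym: i for i, sym in enumerate(symbols)}
    params = ", ".join(f"s{i}" for i in range(len(symbols)))
    results = ", ".join(render_dim(dim, position) for dim in shape)
    return symbols, f"()[{params}] -> ({results})"


# The `cat` output shape from the exported program: f32[s0, 2*s1 + s3, 3].
symbols, text = build_map_text([({"s0": 1}, 0), ({"s1": 2, "s3": 1}, 0), 3])
print(symbols)  # ['s0', 's1', 's3']
print(text)     # ()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)
```

The output reproduces the `%6` binding in the MLIR module: operands `[s0, s1, s3]` (i.e. `[%0, %1, %2]`) with the FX symbol `s3` appearing as map symbol `s2`.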
2024-06-07 04:04:03 -07:00
GlobalizeObjectGraph [torch-mlir][test] cleanup trailing whitespace in mlir files (#2806) 2024-01-25 14:24:13 -08:00
adjust-calling-conventions.mlir Clean up verification of calling conventions. 2023-07-20 20:08:46 +02:00
canonicalize.mlir Representing Symbolic Shape Expressions in Torch Dialect (#3372) 2024-06-07 04:04:03 -07:00
decompose-complex-ops-legal.mlir handles 2,3,4 from https://github.com/llvm/torch-mlir/issues/1963 (#1964) 2023-03-24 21:50:01 -05:00
decompose-complex-ops.mlir [Torch] Fix bugs for `Torch::AtenOneHotOp` (#3350) 2024-05-22 17:19:08 +00:00
drop-abstract-interp-calculations.mlir [custom op] Generalize shape library logic to work with dtypes (#1594) 2022-12-13 08:25:41 -08:00
erase-module-initializer.mlir Iteratively run the main simplification pipeline. 2022-08-17 14:54:33 -07:00
fuse-quantized-ops.mlir Generalize Operand Quantization in FuseQuantizeOps (#3327) 2024-05-12 20:49:59 -07:00
inline-global-slots-analysis.mlir Rework how global slot initializers work. 2022-08-08 18:12:06 -07:00
inline-global-slots-transform.mlir Rework how global slot initializers work. 2022-08-08 18:12:06 -07:00
invalid.mlir Representing Symbolic Shape Expressions in Torch Dialect (#3372) 2024-06-07 04:04:03 -07:00
lower-to-backend-contract-error.mlir Allow running DecomposeComplexOps more than once (#1671) 2022-12-08 09:26:38 -08:00
match-quantized-customs-ops.mlir [torch-mlir][test] cleanup trailing whitespace in mlir files (#2806) 2024-01-25 14:24:13 -08:00
maximize-value-semantics.mlir Add alias analysis for cast-like ops to maximize-value-semantics (#2160) 2023-05-25 17:05:41 +00:00
ops.mlir [NFC reformat] Run pre-commit on all files and format misc. 2024-04-27 14:08:09 -07:00
prepare-for-globalize-object-graph.mlir mlir: bump llvm tag to 5380e3 (#856) 2022-05-16 12:54:35 -07:00
reduce-op-variants-error.mlir mlir: bump llvm tag to 5380e3 (#856) 2022-05-16 12:54:35 -07:00
reduce-op-variants.mlir [torch] Fix tm_tensor.attention for end-to-end (#2907) 2024-02-13 21:18:01 -08:00
refine-public-return.mlir Support `DerefineOp` in `RefinePublicReturn`. 2023-07-20 20:08:46 +02:00
reify-dtype-calculations.mlir Breakup python pytorch deps (#2582) 2023-11-19 12:10:19 -08:00
reify-shape-calculations.mlir Cast `number` to `float` when shape function takes Scalar arg (#1978) 2023-03-28 09:30:31 -07:00
scalarize-shapes.mlir [torch] Improve shape inference for `torch-to-linalg` path for reshapes (#3055) 2024-03-26 12:41:40 -07:00
simplify-dtype-calculations.mlir [MLIR][TORCH] Add E2E support for view_as_real op (#2419) 2023-09-01 21:12:01 -07:00
simplify-shape-calculations.mlir [torch-mlir][test] cleanup trailing whitespace in mlir files (#2806) 2024-01-25 14:24:13 -08:00
torch-function-to-torch-backend-pipeline.mlir [Torch Dialect] fix torch.uint8's dtype infer (#2227) 2023-06-13 10:38:20 +08:00
torch-nary-canonicalize.mlir [torch] Folders for `torch.aten.*.tensor` operators [add, sub, mul] (#2878) 2024-02-19 10:28:23 -08:00
verify-backend-contract-error.mlir Clean up verification of calling conventions. 2023-07-20 20:08:46 +02:00
verify-backend-contract-unimplemented-op.mlir LowerToBackendContract: Explicitly error out on unimplemented operator (#1947) 2023-03-20 16:27:08 +01:00