2023-12-22 00:40:10 +08:00
|
|
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
|
|
|
|
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
|
2024-03-22 05:44:54 +08:00
|
|
|
from typing import List
|
2023-12-22 00:40:10 +08:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
2024-03-15 01:26:34 +08:00
|
|
|
from torch.export import Dim
|
2024-03-22 05:44:54 +08:00
|
|
|
from torch._dynamo.backends.common import aot_autograd
|
2024-04-28 05:16:31 +08:00
|
|
|
from torch._functorch.aot_autograd import (
|
|
|
|
make_boxed_compiler,
|
|
|
|
get_aot_graph_name,
|
|
|
|
set_model_name,
|
|
|
|
)
|
2023-12-22 00:40:10 +08:00
|
|
|
|
2024-02-07 11:07:59 +08:00
|
|
|
from torch_mlir import fx
|
2024-04-16 14:14:19 +08:00
|
|
|
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
2023-12-22 00:40:10 +08:00
|
|
|
|
|
|
|
|
|
|
|
def run(f):
    """Execute test function ``f`` immediately, framed by a printed banner.

    The banner is the function's name followed by a row of dashes of the
    same length; an empty line is printed once ``f`` has finished.

    Args:
        f: Zero-argument callable that exposes ``__name__``.

    Returns:
        ``f`` itself. When ``run`` is used as a decorator this keeps the
        decorated name bound to the function rather than rebinding it to
        ``None``.
    """
    name = f.__name__
    print(name)
    print("-" * len(name))
    f()
    print()
    # Hand the function back so the decorated name stays usable afterwards.
    return f
|
|
|
|
|
|
|
|
|
|
|
|
@run
# CHECK-LABEL: test_import_frozen_exported_program
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_4_torch.float32> : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32>
# CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_3_1_torch.float32> : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32>
# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]]
# CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]]
# CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]]
# CHECK-DAG: %[[mul_p:.+]] = torch.aten.mul.Tensor %[[mul_b]], %[[p]]
# CHECK: return %[[mul_p]]
#
# Validate dialect resources exist.
# CHECK: dialect_resources:
# CHECK-DAG: torch_tensor_1_4_torch.float32
# CHECK-DAG: torch_tensor_3_1_torch.float32
def test_import_frozen_exported_program():
    """Import a module that multiplies its input by three frozen tensors.

    Covers the basic structural premise of importing a frozen exported
    program: free tensors (buffers) and parameters are expected to show up
    as literals in the printed module.
    """

    @torch._dynamo.assume_constant_result
    def get_a():
        # Constant-result helper: yields a fresh (1, 4) random tensor.
        return torch.randn(1, 4)

    class Basic(nn.Module):
        def __init__(self):
            super().__init__()
            # Plain tensor attribute (not an nn.Parameter).
            self.b = torch.randn(3, 1)
            # Trainable parameter.
            self.p = nn.Parameter(torch.randn(1, 1))

        def forward(self, x):
            activated = torch.tanh(x)
            # Multiplication order matters for the expected op chain.
            return activated * get_a() * self.b * self.p

    model = Basic()
    sample_input = torch.randn(3, 4)
    imported = fx.export_and_import(model, sample_input)
    print(imported)
|
2024-02-27 02:08:14 +08:00
|
|
|
|
|
|
|
|
|
|
|
@run
# CHECK-LABEL: test_import_frozen_exported_program_with_func_name
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
def test_import_frozen_exported_program_with_func_name():
    """Import the frozen-tensor module under a caller-chosen function name.

    Uses the same model shape as the basic frozen-program test, but passes
    ``func_name="test_net"`` so the printed function carries that name.
    """

    @torch._dynamo.assume_constant_result
    def get_a():
        # Constant-result helper: yields a fresh (1, 4) random tensor.
        return torch.randn(1, 4)

    class Basic(nn.Module):
        def __init__(self):
            super().__init__()
            # Plain tensor attribute (not an nn.Parameter).
            self.b = torch.randn(3, 1)
            # Trainable parameter.
            self.p = nn.Parameter(torch.randn(1, 1))

        def forward(self, x):
            activated = torch.tanh(x)
            return activated * get_a() * self.b * self.p

    model = Basic()
    sample_input = torch.randn(3, 4)
    imported = fx.export_and_import(model, sample_input, func_name="test_net")
    print(imported)
|
2024-03-15 01:26:34 +08:00
|
|
|
|
2024-04-28 05:16:31 +08:00
|
|
|
|
2024-03-15 01:26:34 +08:00
|
|
|
@run
# CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32>
# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,4],f32> -> !torch.vtensor<[?,4],f32>
# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
# CHECK: return %[[TANH]] : !torch.vtensor<[?,4],f32>
def test_import_frozen_exported_program_with_dynamic_shapes():
    """Import a tanh module whose leading input dimension is dynamic.

    Dimension 0 of ``x`` is tied to a ``Dim`` named "batch", and symbolic
    shape expressions are requested from the importer.
    """

    class Basic(nn.Module):
        # No state of its own, so the inherited constructor is sufficient.
        def forward(self, x):
            return torch.tanh(x)

    # Only the first axis of "x" is dynamic.
    shape_spec = {"x": {0: Dim("batch")}}

    imported = fx.export_and_import(
        Basic(),
        torch.randn(3, 4),
        dynamic_shapes=shape_spec,
        func_name="test_net",
        import_symbolic_shape_expressions=True,
    )
    print(imported)
|
2024-03-22 05:44:54 +08:00
|
|
|
|
|
|
|
|
2024-04-30 00:21:12 +08:00
|
|
|
@run
# CHECK-LABEL: test_broadcast_with_dynamic_shapes
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32>
# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
# CHECK: torch.aten.size.int
# CHECK: torch.prim.ListConstruct
# CHECK: %[[EXPAND:.*]] = torch.aten.expand
# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],f32>
def test_broadcast_with_dynamic_shapes():
    """Import a broadcast whose target row count comes from a dynamic dim.

    ``x`` keeps a static shape, while dimension 0 of ``y`` is tied to a
    ``Dim`` named "dim_0"; the module broadcasts ``x`` to ``y.shape[0]``
    rows.
    """

    class Basic(nn.Module):
        # No state of its own, so the inherited constructor is sufficient.
        def forward(self, x, y):
            target_rows = y.shape[0]
            # -1 leaves the trailing dimension of x unchanged.
            return torch.broadcast_to(x, (target_rows, -1))

    # Example inputs used for export.
    x = torch.randn(1, 2)
    y = torch.randn(10)

    # "x" is fully static; only the single axis of "y" is dynamic.
    shape_spec = {
        "x": {},
        "y": {0: Dim("dim_0")},
    }

    imported = fx.export_and_import(
        Basic(),
        x,
        y,
        dynamic_shapes=shape_spec,
        func_name="test_net",
        import_symbolic_shape_expressions=True,
    )
    print(imported)
|
|
|
|
|
|
|
|
|
2024-03-22 05:44:54 +08:00
|
|
|
@make_boxed_compiler
def fx_import_aot_autograd_backend(
    gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
):
    """Compiler callback that prints a graph and its stateless FX import.

    Args:
        gm: Graph module handed to the compiler callback.
        example_inputs: Part of the callback signature; not used here.

    Returns:
        ``gm`` unchanged, so the original graph is what gets executed.
    """
    # print_readable(False) is printed explicitly here.
    # NOTE(review): presumably False stops print_readable from printing by
    # itself and makes it return the text only; confirm against torch.fx.
    readable_graph = gm.print_readable(False)
    print(readable_graph, flush=True)

    imported = fx.stateless_fx_import(gm, model_name=get_aot_graph_name())
    print(imported, flush=True)
    return gm
|
|
|
|
|
2024-04-28 05:16:31 +08:00
|
|
|
|
2024-03-22 05:44:54 +08:00
|
|
|
@run
# CHECK-LABEL: test_stateless_fx_import
# CHECK: func.func @[[basic:[a-zA-Z0-9_]+]](%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
# CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32>
def test_stateless_fx_import():
    """Run a tanh function through a dynamo backend that imports via FX.

    The backend wraps ``fx_import_aot_autograd_backend`` as the forward
    compiler; calling the optimized function once triggers it.
    """
    backend = aot_autograd(fw_compiler=fx_import_aot_autograd_backend)
    set_model_name("basic_forward")

    @torch._dynamo.optimize(backend=backend)
    def basic_forward(x):
        return torch.tanh(x)

    sample_input = torch.randn(3, 4)
    basic_forward(sample_input)
|
2024-04-16 14:14:19 +08:00
|
|
|
|
|
|
|
|
|
|
|
@run
# CHECK-LABEL: test_full
# CHECK: %2 = torch.aten.fill.Scalar %1, %int0 : !torch.vtensor<[],i1>, !torch.int -> !torch.vtensor<[],i1>
def test_full():
    """Import a zero-argument module returning ``torch.full`` of a bool scalar.

    After import, the torch simplification pipeline is run on the module
    and the result is printed.
    """

    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self):
            # Scalar (empty shape) boolean tensor filled with False.
            return torch.full(
                [],
                False,
                dtype=torch.bool,
                layout=torch.strided,
                device="cpu",
                pin_memory=False,
            )

    m = fx.export_and_import(Basic(), func_name="test_full", enable_graph_printing=True)
    run_pipeline_with_repro_report(
        m,
        # Plain literal: the string has no placeholders, so no f-prefix.
        "builtin.module(torch-simplification-pipeline)",
        "torch-simplification-pipeline",
    )
    print(m)
|