# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
from typing import List
import torch
import torch.nn as nn
from torch.export import Dim
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.aot_autograd import (
    make_boxed_compiler,
    get_aot_graph_name,
    set_model_name,
)
from torch_mlir import fx
from torch_mlir.compiler_utils import run_pipeline_with_repro_report


def run(f):
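    # Simple lit-test harness: print the test's name (matched by the
    # CHECK-LABEL lines below), an underline, then invoke the test.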
print(f"{f.__name__}")
print("-" * len(f.__name__))
f()
print()
@run
# CHECK-LABEL: test_import_frozen_exported_program
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_4_torch.float32> : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32>
# CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_3_1_torch.float32> : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32>
# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]]
# CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]]
# CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]]
# CHECK-DAG: %[[mul_p:.+]] = torch.aten.mul.Tensor %[[mul_b]], %[[p]]
# CHECK: return %[[mul_p]]
#
# Validate dialect resources exist.
# CHECK: dialect_resources:
# CHECK-DAG: torch_tensor_1_4_torch.float32
# CHECK-DAG: torch_tensor_3_1_torch.float32
def test_import_frozen_exported_program():
    # Tests the basic structural premises of import_frozen_exported_program,
    # namely that free tensors (buffers) and parameters are treated as
    # literals and frozen.
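    # get_a is wrapped in assume_constant_result, so export folds its return
    # value into the graph as a constant, to be frozen like b and p below.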
    @torch._dynamo.assume_constant_result
    def get_a():
        return torch.randn(1, 4)

    class Basic(nn.Module):
        def __init__(self):
            super().__init__()
            self.b = torch.randn(3, 1)
            self.p = nn.Parameter(torch.randn(1, 1))

        def forward(self, x):
            return torch.tanh(x) * get_a() * self.b * self.p

    m = fx.export_and_import(Basic(), torch.randn(3, 4))
    print(m)


@run
# CHECK-LABEL: test_import_frozen_exported_program_with_func_name
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
def test_import_frozen_exported_program_with_func_name():
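    # Same module as above; verifies that func_name= overrides the default
    # entry-point name ("main") on the emitted func.func.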
    @torch._dynamo.assume_constant_result
    def get_a():
        return torch.randn(1, 4)

    class Basic(nn.Module):
        def __init__(self):
            super().__init__()
            self.b = torch.randn(3, 1)
            self.p = nn.Parameter(torch.randn(1, 1))

        def forward(self, x):
            return torch.tanh(x) * get_a() * self.b * self.p

    m = fx.export_and_import(Basic(), torch.randn(3, 4), func_name="test_net")
    print(m)


@run
# CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32>
# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,4],f32> -> !torch.vtensor<[?,4],f32>
# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
# CHECK: return %[[TANH]] : !torch.vtensor<[?,4],f32>
def test_import_frozen_exported_program_with_dynamic_shapes():
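    # Exports with dim 0 marked dynamic via torch.export.Dim; with
    # import_symbolic_shape_expressions=True the importer materializes
    # torch.symbolic_int and torch.bind_symbolic_shape ops for it.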
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.tanh(x)

    batch = Dim("batch", max=10)
    dynamic_shapes = {"x": {0: batch}}
    m = fx.export_and_import(
        Basic(),
        torch.randn(3, 4),
        dynamic_shapes=dynamic_shapes,
        func_name="test_net",
        import_symbolic_shape_expressions=True,
    )
    print(m)


@run
# CHECK-LABEL: test_broadcast_with_dynamic_shapes
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32>
# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
# CHECK: torch.aten.size.int
# CHECK: torch.prim.ListConstruct
# CHECK: %[[EXPAND:.*]] = torch.aten.expand
# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],f32>
def test_broadcast_with_dynamic_shapes():
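    # Broadcasts x along a dimension whose extent comes from another input's
    # dynamic dim; the expanded result should carry the bound symbolic shape
    # (s0, 2) checked above.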
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return torch.broadcast_to(x, (y.shape[0], -1))

    # Sample inputs
    x = torch.randn(1, 2)
    y = torch.randn(10)

    dim_0 = Dim("dim_0", max=10)
    dynamic_shapes = {
        "x": {},
        "y": {0: dim_0},
    }
    m = fx.export_and_import(
        Basic(),
        x,
        y,
        dynamic_shapes=dynamic_shapes,
        func_name="test_net",
        import_symbolic_shape_expressions=True,
    )
    print(m)


@make_boxed_compiler
def fx_import_aot_autograd_backend(
    gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
):
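    # Forward-compiler hook for aot_autograd: print the FX graph, import it
    # into Torch-MLIR without freezing any state, print the module, and hand
    # the graph back unchanged so execution proceeds as usual.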
    print(gm.print_readable(False), flush=True)
    m = fx.stateless_fx_import(gm, model_name=get_aot_graph_name())
    print(m, flush=True)
    return gm


@run
# CHECK-LABEL: test_stateless_fx_import
# CHECK: func.func @[[basic:[a-zA-Z0-9_]+]](%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
# CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32>
def test_stateless_fx_import():
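    # Compiles a tanh through torch._dynamo with the backend above, which
    # imports (and prints) the traced forward graph at compile time.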
    fx_import_backend = aot_autograd(fw_compiler=fx_import_aot_autograd_backend)
    set_model_name("basic_forward")

    @torch._dynamo.optimize(backend=fx_import_backend)
    def basic_forward(x):
        return torch.tanh(x)

    basic_forward(torch.randn(3, 4))


@run
# CHECK-LABEL: test_full
# CHECK: %2 = torch.aten.fill.Scalar %1, %int0 : !torch.vtensor<[],i1>, !torch.int -> !torch.vtensor<[],i1>
def test_full():
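    # torch.full on a rank-0 shape; the simplification pipeline below is
    # expected to decompose it into a fill.Scalar, which the CHECK matches.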
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self):
            return torch.full(
                [],
                False,
                dtype=torch.bool,
                layout=torch.strided,
                device="cpu",
                pin_memory=False,
            )

    m = fx.export_and_import(
        Basic(), func_name="test_full", enable_graph_printing=True
    )
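    # Run the torch simplification pipeline in place on the imported module
    # (with a repro report on failure) before printing the final IR.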
    run_pipeline_with_repro_report(
        m,
        "builtin.module(torch-simplification-pipeline)",
        "torch-simplification-pipeline",
    )
    print(m)