torch-mlir/test/python/fx_importer/basic_test.py


[fx] Upstream the turbine FxImporter to torch-mlir. (#2681) Changes made during upstreaming: * Removed comments attributing some copied code back to torch-mlir (since it is now repatriated). * Re-organized imports. * Inlined RefMapping/RefTracker and TypeSubclassMap from an external utility module. * Added FxImporter class comments. * Updated stack trace extraction to be fail safe. * Added an entry-point for `import_frozen_exported_program` which uses the shiny new upstream `torch.export.export()` API (versus the lower-level/older API that Turbine is presently using). This necessitated a small FX rewrite to line external state management up with current conventions. * Adapted one of Turbine's importer tests to go with this initial submission. Turbine unfortunately has a lot of more-integration-ey tests, and I would like to extract those as more of unit tests of the importer features and upstream them that way vs trying to copy directly. For now, one overall test with the initial submission gets us moving. I acknowledge that there are some code quality things that could be improved in this submission: this was authored over the course of many months (and often via some trial and error). I would like to keep it relatively converged with the downstream for the next few steps while getting the test suite upstreamed. And then it will be easier to take a hygienic pass through the code. Including co-authors for contributors in the git log of the original repository. Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com> Co-authored-by: Avinash Sharma <aviator1994@gmail.com> Co-authored-by: Arham Khan <arhammkhan@gmail.com> Co-authored-by: brucekimrokcmu <kwangkyk@alumni.cmu.edu> Co-authored-by: saienduri <77521230+saienduri@users.noreply.github.com>
2023-12-22 00:40:10 +08:00
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
2024-03-22 05:44:54 +08:00
from typing import List
import torch
import torch.nn as nn
from torch.export import Dim
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.aot_autograd import (
    make_boxed_compiler,
    get_aot_graph_name,
    set_model_name,
)
from torch_mlir import fx
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
def run(f):
    print(f"{f.__name__}")
    print("-" * len(f.__name__))
    f()
    print()

@run
# CHECK-LABEL: test_import_frozen_exported_program
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_4_torch.float32> : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32>
# CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_3_1_torch.float32> : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32>
# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]]
# CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]]
# CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]]
# CHECK-DAG: %[[mul_p:.+]] = torch.aten.mul.Tensor %[[mul_b]], %[[p]]
# CHECK: return %[[mul_p]]
#
# Validate dialect resources exist.
# CHECK: dialect_resources:
# CHECK-DAG: torch_tensor_1_4_torch.float32
# CHECK-DAG: torch_tensor_3_1_torch.float32
def test_import_frozen_exported_program():
    # Tests the basic structural premises of import_frozen_exported_program,
    # namely that free tensors (buffers) and parameters are treated as
    # literals and frozen.
    @torch._dynamo.assume_constant_result
    def get_a():
        return torch.randn(1, 4)

    class Basic(nn.Module):
        def __init__(self):
            super().__init__()
            self.b = torch.randn(3, 1)
            self.p = nn.Parameter(torch.randn(1, 1))

        def forward(self, x):
            return torch.tanh(x) * get_a() * self.b * self.p

    m = fx.export_and_import(Basic(), torch.randn(3, 4))
    print(m)

@run
# CHECK-LABEL: test_import_frozen_exported_program_with_func_name
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
def test_import_frozen_exported_program_with_func_name():
    @torch._dynamo.assume_constant_result
    def get_a():
        return torch.randn(1, 4)

    class Basic(nn.Module):
        def __init__(self):
            super().__init__()
            self.b = torch.randn(3, 1)
            self.p = nn.Parameter(torch.randn(1, 1))

        def forward(self, x):
            return torch.tanh(x) * get_a() * self.b * self.p

    m = fx.export_and_import(Basic(), torch.randn(3, 4), func_name="test_net")
    print(m)

@run
# CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32>
def test_import_frozen_exported_program_with_dynamic_shapes():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.tanh(x)

    batch = Dim("batch")
    dynamic_shapes = {"x": {0: batch}}
    m = fx.export_and_import(
        Basic(), torch.randn(3, 4), dynamic_shapes=dynamic_shapes, func_name="test_net"
    )
    print(m)

@run
# CHECK-LABEL: test_broadcast_with_dynamic_shapes
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32>
def test_broadcast_with_dynamic_shapes():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return torch.broadcast_to(x, (y.shape[0], -1))

    # Sample inputs
    x = torch.randn(1, 2)
    y = torch.randn(10)

    dim_0 = Dim("dim_0")
    dynamic_shapes = {
        "x": {},
        "y": {0: dim_0},
    }

    m = fx.export_and_import(
        Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net"
    )
    print(m)

@make_boxed_compiler
def fx_import_aot_autograd_backend(
    gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
):
    print(gm.print_readable(False), flush=True)
    m = fx.stateless_fx_import(gm, model_name=get_aot_graph_name())
    print(m, flush=True)
    return gm

@run
# CHECK-LABEL: test_stateless_fx_import
# CHECK: func.func @[[basic:[a-zA-Z0-9_]+]](%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
# CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32>
def test_stateless_fx_import():
    fx_import_backend = aot_autograd(fw_compiler=fx_import_aot_autograd_backend)
    set_model_name("basic_forward")

    @torch._dynamo.optimize(backend=fx_import_backend)
    def basic_forward(x):
        return torch.tanh(x)

    basic_forward(torch.randn(3, 4))

@run
# CHECK-LABEL: test_full
# CHECK: %2 = torch.aten.fill.Scalar %1, %int0 : !torch.vtensor<[],i1>, !torch.int -> !torch.vtensor<[],i1>
def test_full():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self):
            return torch.full(
                [],
                False,
                dtype=torch.bool,
                layout=torch.strided,
                device="cpu",
                pin_memory=False,
            )

    m = fx.export_and_import(Basic(), func_name="test_full", enable_graph_printing=True)
    run_pipeline_with_repro_report(
        m,
        "builtin.module(torch-simplification-pipeline)",
        "torch-simplification-pipeline",
    )
    print(m)