torch-mlir/test/python/fx_importer/basic_test.py


[fx] Upstream the turbine FxImporter to torch-mlir. (#2681)

Changes made during upstreaming:

* Removed comments attributing some copied code back to torch-mlir (since it is now repatriated).
* Re-organized imports.
* Inlined RefMapping/RefTracker and TypeSubclassMap from an external utility module.
* Added FxImporter class comments.
* Updated stack trace extraction to be fail safe.
* Added an entry-point for `import_frozen_exported_program` which uses the shiny new upstream `torch.export.export()` API (versus the lower-level/older API that Turbine is presently using). This necessitated a small FX rewrite to line external state management up with current conventions.
* Adapted one of Turbine's importer tests to go with this initial submission. Turbine unfortunately has a lot of more-integration-ey tests, and I would like to extract those as more of unit tests of the importer features and upstream them that way vs trying to copy directly. For now, one overall test with the initial submission gets us moving.

I acknowledge that there are some code quality things that could be improved in this submission: this was authored over the course of many months (and often via some trial and error). I would like to keep it relatively converged with the downstream for the next few steps while getting the test suite upstreamed. And then it will be easier to take a hygienic pass through the code.

Including co-authors for contributors in the git log of the original repository.

Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: Avinash Sharma <aviator1994@gmail.com>
Co-authored-by: Arham Khan <arhammkhan@gmail.com>
Co-authored-by: brucekimrokcmu <kwangkyk@alumni.cmu.edu>
Co-authored-by: saienduri <77521230+saienduri@users.noreply.github.com>
2023-12-22 00:40:10 +08:00
# Copyright 2023 Advanced Micro Devices, Inc
#
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
from typing import Optional
import torch
import torch.export
import torch.nn as nn
from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d


def export_and_import(
    f,
    *args,
    fx_importer: Optional[FxImporter] = None,
    constraints: Optional[torch.export.Constraint] = None,
    **kwargs,
):
    context = ir.Context()
    torch_d.register_dialect(context)

    if fx_importer is None:
        fx_importer = FxImporter(context=context)
    prog = torch.export.export(f, args, kwargs, constraints=constraints)
    fx_importer.import_frozen_exported_program(prog)
    return fx_importer.module_op


def run(f):
    print(f"{f.__name__}")
    print("-" * len(f.__name__))
    f()
    print()


@run
# CHECK-LABEL: test_import_frozen_exported_program
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_4_torch.float32> : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32>
# CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_3_1_torch.float32> : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32>
# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_1_torch.float32> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]]
# CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]]
# CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]]
# CHECK-DAG: %[[mul_p:.+]] = torch.aten.mul.Tensor %[[mul_b]], %[[p]]
# CHECK: return %[[mul_p]]
#
# Validate dialect resources exist.
# CHECK: dialect_resources:
# CHECK-DAG: torch_tensor_1_4_torch.float32
# CHECK-DAG: torch_tensor_3_1_torch.float32
# CHECK-DAG: torch_tensor_1_1_torch.float32
def test_import_frozen_exported_program():
    # Tests the basic structural premises of import_frozen_exported_program,
    # namely that free tensors (buffers) and parameters are treated as
    # literals and frozen.
    @torch._dynamo.assume_constant_result
    def get_a():
        return torch.randn(1, 4)

    class Basic(nn.Module):
        def __init__(self):
            super().__init__()
            self.b = torch.randn(3, 1)
            self.p = nn.Parameter(torch.randn(1, 1))

        def forward(self, x):
            return torch.tanh(x) * get_a() * self.b * self.p

    m = export_and_import(Basic(), torch.randn(3, 4))
    print(m)