# torch-mlir/test/python/fx_importer/basic_test.py
#
# 81 lines
# 2.8 KiB
# Python

# Copyright 2023 Advanced Micro Devices, Inc
#
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
from typing import Optional
import torch
import torch.export
import torch.nn as nn
from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
def export_and_import(
    f,
    *args,
    fx_importer: Optional[FxImporter] = None,
    constraints: Optional[torch.export.Constraint] = None,
    **kwargs,
):
    """Export ``f`` with ``torch.export`` and import the result into MLIR.

    Args:
        f: The callable or module handed to ``torch.export.export``.
        *args: Positional example inputs, forwarded to the exporter as a
            single tuple.
        fx_importer: Importer to use. When omitted, a new ``FxImporter`` is
            built on the context created by this function.
        constraints: Forwarded unchanged as the ``constraints`` keyword of
            ``torch.export.export``.
        **kwargs: Forwarded to the exporter as a single dict, right after the
            positional inputs.

    Returns:
        The ``module_op`` attribute of the importer after the exported
        program has been imported in frozen form.
    """
    # The context is created and the torch dialect registered on every call;
    # if the caller supplied an importer, this function does not use the new
    # context any further.
    ctx = ir.Context()
    torch_d.register_dialect(ctx)

    importer = FxImporter(context=ctx) if fx_importer is None else fx_importer

    exported = torch.export.export(f, args, kwargs, constraints=constraints)
    importer.import_frozen_exported_program(exported)
    return importer.module_op
def run(f):
    """Print a titled banner for ``f``, call it, then print a blank line.

    The output is the function's ``__name__``, a row of dashes of the same
    length, whatever ``f`` itself prints, and finally an empty line.

    Nothing is returned, so when this is applied as a decorator the test runs
    immediately at definition time and the decorated name ends up bound to
    ``None``.
    """
    label = f.__name__
    print(label)
    print("-" * len(label))
    f()
    print()
@run
# CHECK-LABEL: test_import_frozen_exported_program
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_4_torch.float32> : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32>
# CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_3_1_torch.float32> : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32>
# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_1_torch.float32> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]]
# CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]]
# CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]]
# CHECK-DAG: %[[mul_p:.+]] = torch.aten.mul.Tensor %[[mul_b]], %[[p]]
# CHECK: return %[[mul_p]]
#
# Validate dialect resources exist.
# CHECK: dialect_resources:
# CHECK-DAG: torch_tensor_1_4_torch.float32
# CHECK-DAG: torch_tensor_3_1_torch.float32
# CHECK-DAG: torch_tensor_1_1_torch.float32
def test_import_frozen_exported_program():
    """Exercise the basic structural premises of frozen program import.

    The module below multiplies its input by three tensors that are not
    function arguments: one returned by a helper marked as a constant result,
    one stored as a plain attribute, and one registered as a parameter. The
    FileCheck directives above expect each of them to show up as a frozen
    tensor literal in the printed module.
    """

    # Free tensor obtained through a call that the tracer is told to treat as
    # a constant.
    @torch._dynamo.assume_constant_result
    def get_a():
        return torch.randn(1, 4)

    class Basic(nn.Module):
        def __init__(self):
            super().__init__()
            # Plain tensor attribute (not a registered parameter).
            self.b = torch.randn(3, 1)
            self.p = nn.Parameter(torch.randn(1, 1))

        def forward(self, x):
            # Same left-to-right chain as tanh(x) * a * b * p, one
            # multiplication per statement.
            activated = torch.tanh(x)
            scaled = activated * get_a()
            scaled = scaled * self.b
            return scaled * self.p

    # Build the module before drawing the example input so the random draws
    # happen in the same order as a single nested call expression would give.
    model = Basic()
    example_input = torch.randn(3, 4)
    module_op = export_and_import(model, example_input)
    print(module_op)