# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s | FileCheck %s from typing import List import torch import torch.nn as nn from torch.export import Dim from torch._dynamo.backends.common import aot_autograd from torch._functorch.aot_autograd import ( make_boxed_compiler, get_aot_graph_name, set_model_name, ) from torch_mlir import fx from torch_mlir.compiler_utils import run_pipeline_with_repro_report def run(f): print(f"{f.__name__}") print("-" * len(f.__name__)) f() print() @run # CHECK-LABEL: test_import_frozen_exported_program # CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> # CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32> # CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32> # CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> # CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]] # CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]] # CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]] # CHECK-DAG: %[[mul_p:.+]] = torch.aten.mul.Tensor %[[mul_b]], %[[p]] # CHECK: return %[[mul_p]] # # Validate dialect resources exist. # CHECK: dialect_resources: # CHECK-DAG: torch_tensor_1_4_torch.float32 # CHECK-DAG: torch_tensor_3_1_torch.float32 def test_import_frozen_exported_program(): # Tests the basic structural premises of import_frozen_exported_program, # namely that free tensors (buffers) and parameters are treated as # literals and frozen. @torch._dynamo.assume_constant_result def get_a(): return torch.randn(1, 4) class Basic(nn.Module): def __init__(self): super().__init__() self.b = torch.randn(3, 1) self.p = nn.Parameter(torch.randn(1, 1)) def forward(self, x): return torch.tanh(x) * get_a() * self.b * self.p m = fx.export_and_import(Basic(), torch.randn(3, 4)) print(m) @run # CHECK-LABEL: test_import_frozen_exported_program_with_func_name # CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> def test_import_frozen_exported_program_with_func_name(): @torch._dynamo.assume_constant_result def get_a(): return torch.randn(1, 4) class Basic(nn.Module): def __init__(self): super().__init__() self.b = torch.randn(3, 1) self.p = nn.Parameter(torch.randn(1, 1)) def forward(self, x): return torch.tanh(x) * get_a() * self.b * self.p m = fx.export_and_import(Basic(), torch.randn(3, 4), func_name="test_net") print(m) @run # CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes # CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32> # CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int # CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32> # CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,4],f32> -> !torch.vtensor<[?,4],f32> # CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32> # CHECK: return %[[TANH]] : !torch.vtensor<[?,4],f32> def test_import_frozen_exported_program_with_dynamic_shapes(): class Basic(nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.tanh(x) batch = Dim("batch", max=10) dynamic_shapes = {"x": {0: batch}} m = fx.export_and_import( Basic(), torch.randn(3, 4), dynamic_shapes=dynamic_shapes, func_name="test_net", import_symbolic_shape_expressions=True, ) print(m) @run # CHECK-LABEL: test_broadcast_with_dynamic_shapes # CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32> # CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int # CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> # CHECK: torch.aten.size.int # CHECK: torch.prim.ListConstruct # CHECK: %[[EXPAND:.*]] = torch.aten.expand # CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],f32> def test_broadcast_with_dynamic_shapes(): class Basic(nn.Module): def __init__(self): super().__init__() def forward(self, x, y): return torch.broadcast_to(x, (y.shape[0], -1)) # Sample inputs x = torch.randn(1, 2) y = torch.randn(10) dim_0 = Dim("dim_0", max=10) dynamic_shapes = { "x": {}, "y": {0: dim_0}, } m = fx.export_and_import( Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net", import_symbolic_shape_expressions=True, ) print(m) @make_boxed_compiler def fx_import_aot_autograd_backend( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor] ): print(gm.print_readable(False), flush=True) m = fx.stateless_fx_import(gm, model_name=get_aot_graph_name()) print(m, flush=True) return gm @run # CHECK-LABEL: test_stateless_fx_import # CHECK: func.func @[[basic:[a-zA-Z0-9_]+]](%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> # CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> # CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32> def test_stateless_fx_import(): fx_import_backend = aot_autograd(fw_compiler=fx_import_aot_autograd_backend) set_model_name("basic_forward") @torch._dynamo.optimize(backend=fx_import_backend) def basic_forward(x): return torch.tanh(x) basic_forward(torch.randn(3, 4)) @run # CHECK-LABEL: test_full # CHECK: %2 = torch.aten.fill.Scalar %1, %int0 : !torch.vtensor<[],i1>, !torch.int -> !torch.vtensor<[],i1> def test_full(): class Basic(nn.Module): def __init__(self): super().__init__() def forward(self): return torch.full( [], False, dtype=torch.bool, layout=torch.strided, device="cpu", pin_memory=False, ) m = fx.export_and_import(Basic(), func_name="test_full", enable_graph_printing=True) run_pipeline_with_repro_report( m, f"builtin.module(torch-simplification-pipeline)", "torch-simplification-pipeline", ) print(m)