# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
from typing import List
import torch
import torch.fx
import torch._dynamo as dynamo
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.aot_autograd import (make_boxed_compiler,
                                           get_aot_compilation_context,
                                           set_model_name)
from torch_mlir.compiler_utils import TorchMlirCompilerError
from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
from torch_mlir_e2e_test.configs.torchdynamo import jit
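
# An AOT Autograd forward compiler that prints the traced FX graph and the
# torch-mlir module imported from it, then returns the GraphModule unchanged
# so the decorated functions still compute their original results.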
@make_boxed_compiler
def my_aot_autograd_backend(gm: torch.fx.GraphModule,
                            example_inputs: List[torch.Tensor]):
    print(gm.graph)
    *_, model_name, nth_graph = get_aot_compilation_context()
    mlir_module = import_fx_graph_as_func(gm.graph, model_name)
    print(mlir_module.operation.get_asm(enable_debug_info=True))
    return gm


my_backend = aot_autograd(fw_compiler=my_aot_autograd_backend)
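
# `basic` exercises the simplest case: a single aten.tanh op. The final LOC
# check verifies that the imported module carries a debug location pointing
# back to this file.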
# CHECK: module attributes {torch.debug_module_name = "basic"} {
# CHECK-NEXT: func.func @basic(%[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> {
# CHECK-NEXT: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> loc(#[[LOC:.*]])
# CHECK-NEXT: return %[[TANH]] : !torch.vtensor<[3,4],f32>
# CHECK-NEXT: }
# CHECK-NEXT: }
# CHECK-NEXT: #[[LOC]] = loc("{{.*}}/dynamo_fx_importer/basic.py":{{[0-9]+}}:{{[0-9]+}})
@dynamo.optimize(my_backend)
def basic(x):
    return torch.tanh(x)


set_model_name("basic")
basic(torch.randn(3, 4))
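
# Exercises importing of list, int, None, and device literals; the
# dtype=torch.float16 keyword is imported as its integer scalar-type code
# (torch.constant.int 5).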
# CHECK-LABEL: func.func @literals_list_device_int_none_dtype() -> !torch.vtensor<[3,4],f16> {
# CHECK: %[[INT3:.*]] = torch.constant.int 3
# CHECK: %[[INT4:.*]] = torch.constant.int 4
# CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT4]] : (!torch.int, !torch.int) -> !torch.list<int>
# CHECK: %[[INT5:.*]] = torch.constant.int 5
# CHECK: %[[NONE0:.*]] = torch.constant.none
# CHECK: %[[DEVICE_CPU:.*]] = torch.constant.device "cpu"
# CHECK: %[[NONE1:.*]] = torch.constant.none
# CHECK: %[[RANDN:.*]] = torch.aten.randn %[[LIST]], %[[INT5]], %[[NONE0]], %[[DEVICE_CPU]], %[[NONE1]] : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.none -> !torch.vtensor<[3,4],f16>
# CHECK: return %[[RANDN]] : !torch.vtensor<[3,4],f16>
@dynamo.optimize(my_backend)
def literals_list_device_int_none_dtype():
    return torch.ops.aten.randn([3, 4],
                                device=torch.device("cpu"),
                                dtype=torch.float16)


set_model_name("literals_list_device_int_none_dtype")
literals_list_device_int_none_dtype()
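
# Exercises importing of bool literals: pin_memory=False becomes
# torch.constant.bool false, and the defaulted arguments become
# torch.constant.none.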
# CHECK-LABEL: func.func @literals_bool(
# CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> {
# CHECK: %[[NONE0:.*]] = torch.constant.none
# CHECK: %[[NONE1:.*]] = torch.constant.none
# CHECK: %[[NONE2:.*]] = torch.constant.none
# CHECK: %[[BOOL_FALSE:.*]] = torch.constant.bool false
# CHECK: %[[NONE3:.*]] = torch.constant.none
# CHECK: %[[EMPTY_LIKE:.*]] = torch.aten.empty_like %[[ARG0]], %[[NONE0]], %[[NONE1]], %[[NONE2]], %[[BOOL_FALSE]], %[[NONE3]] : !torch.vtensor<[3,4],f32>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
# CHECK: return %[[EMPTY_LIKE]] : !torch.vtensor<[3,4],f32>
@dynamo.optimize(my_backend)
def literals_bool(x):
    return torch.ops.aten.empty_like(x, pin_memory=False)


set_model_name("literals_bool")
literals_bool(torch.randn(3, 4))
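
# Exercises importing of float literals: the uniform bounds 0.0 and 1.0
# become torch.constant.float ops.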
# CHECK-LABEL: func.func @literals_float(
# CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> {
# CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
# CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
# CHECK: %[[NONE:.*]] = torch.constant.none
# CHECK: %[[UNIFORM:.*]] = torch.aten.uniform %[[ARG0]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[3,4],f32>
# CHECK: return %[[UNIFORM]] : !torch.vtensor<[3,4],f32>
@dynamo.optimize(my_backend)
def literals_float(x):
    return torch.ops.aten.uniform(x, 0.0, 1.0)


set_model_name("literals_float")
literals_float(torch.randn(3, 4))
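
# Exercises importing of str literals: approximate="tanh" becomes
# torch.constant.str "tanh".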
# CHECK-LABEL: func.func @literals_str(
# CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> loc(unknown)) -> !torch.vtensor<[3,4],f32> {
# CHECK: %[[STR_TANH:.*]] = torch.constant.str "tanh"
# CHECK: %[[GELU:.*]] = torch.aten.gelu %[[ARG0]], %[[STR_TANH]] : !torch.vtensor<[3,4],f32>, !torch.str -> !torch.vtensor<[3,4],f32>
# CHECK: return %[[GELU]] : !torch.vtensor<[3,4],f32>
@dynamo.optimize(my_backend)
def literals_str(x):
    return torch.ops.aten.gelu(x, approximate="tanh")


set_model_name("literals_str")
literals_str(torch.randn(3, 4))