# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: %PYTHON %s | FileCheck %s

import torch

from framework import run_test
from torch_mlir.eager_mode.torch_mlir_dispatch import build_script_function


# CHECK: graph(%[[A1:.*]] : Tensor,
# CHECK: %[[A2:.*]] : Tensor,
# CHECK: %[[A3:.*]] : Tensor):
# CHECK: %[[A4:.*]] : int = prim::Constant[value=1]()
# CHECK: %[[A5:.*]] : int = prim::Constant[value=1]()
# CHECK: %[[A0:.*]] : Tensor = aten::addmm(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]])
# CHECK: return (%[[A0]])
# -----
# CHECK: PASS - simple
@run_test
def simple():
    target = torch.ops.aten.addmm.default
    A = torch.randn(1, 3, 32, 32)
    B = torch.randn(1, 3, 32, 32)
    C = torch.randn(1, 3, 32, 32)
    args = (A, B, C)
    kwargs = dict(beta=1, alpha=1)
    script_fun = build_script_function(target._schema, args, kwargs)
    print(script_fun.graph)


# CHECK: graph(%[[B1:.*]] : Tensor,
# CHECK: %[[B2:.*]] : Tensor,
# CHECK: %[[B3:.*]] : Tensor):
# CHECK: %[[B4:.*]] : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %[[B5:.*]] : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %[[B6:.*]] : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %[[B7:.*]] : bool = prim::Constant[value=0]()
# CHECK: %[[B8:.*]] : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %[[B9:.*]] : int = prim::Constant[value=1]()
# CHECK: %[[B0:.*]] : Tensor = aten::convolution(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[B9]])
# CHECK: return (%[[B0]])
# -----
# CHECK: PASS - handle_optional_tensor_input
@run_test
def handle_optional_tensor_input():
    target = torch.ops.aten.convolution.default
    input = torch.randn(1, 3, 32, 32)
    weight = torch.randn(3, 3, 3, 3)
    bias = torch.randn(3)
    args = (input, weight, bias)
    kwargs = dict(
        stride=[1, 1],
        padding=[0, 0],
        dilation=[1, 1],
        transposed=False,
        output_padding=[0, 0],
        groups=1,
    )
    script_fun = build_script_function(target._schema, args, kwargs)
    print(script_fun.graph)


# CHECK: FAIL - fail_not_enough_args
# CHECK: Errors: tuple index out of range
@run_test
def fail_not_enough_args():
    target = torch.ops.aten.convolution.default
    input = torch.randn(1, 3, 32, 32)
    weight = torch.randn(3, 3, 3, 3)
    bias = torch.randn(3)
    args = (input, weight, bias)
    kwargs = dict(
        stride=[1, 1],
        padding=[0, 0],
        dilation=[1, 1],
        transposed=False,
        output_padding=[0, 0],
        # Missing groups=1,
    )
    build_script_function(target._schema, args, kwargs)


# CHECK: PASS - simple_args_or_kwargs
@run_test
def simple_args_or_kwargs():
    target = torch.ops.aten.convolution.default
    input = torch.randn(1, 3, 32, 32)
    weight = torch.randn(3, 3, 3, 3)
    bias = torch.randn(3)
    stride = [1, 1]
    padding = [0, 0]
    dilation = [1, 1]
    transposed = False
    output_padding = [0, 0]
    groups = 1

    script_fun1 = build_script_function(
        target._schema,
        (input, weight, bias),
        dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            transposed=transposed,
            output_padding=output_padding,
            groups=groups,
        ),
    )
    script_fun2 = build_script_function(
        target._schema,
        (input, weight, bias, stride, padding, dilation),
        dict(transposed=transposed, output_padding=output_padding, groups=groups),
    )
    assert str(script_fun1.graph) == str(script_fun2.graph)


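# Both an empty list ("stride": []) and an explicit None ("stride": None) for
# max_pool2d_with_indices' optional stride should lower to a NoneType constant
# in the built graph; the next two tests check each spelling.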
# CHECK: graph(%[[C2:.*]] : Tensor):
# CHECK: %[[C3:.*]] : int[] = prim::Constant[value=[3, 3]]()
# CHECK: %[[C4:.*]] : NoneType = prim::Constant()
# CHECK: %[[C5:.*]] : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %[[C6:.*]] : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %[[C7:.*]] : bool = prim::Constant[value=0]()
# CHECK: %[[C0:.*]] : Tensor, %[[C1:.*]] : Tensor = aten::max_pool2d_with_indices(%[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]], %[[C7]])
# CHECK: return (%[[C0]], %[[C1]])
# -----
# CHECK: PASS - handle_empty_lists
@run_test
def handle_empty_lists():
    target = torch.ops.aten.max_pool2d_with_indices.default
    # print(target._schema)
    args = (torch.randn((1, 3, 32, 32)),)
    kwargs = {
        "kernel_size": [3, 3],
        "stride": [],
        "padding": [0, 0],
        "dilation": [1, 1],
        "ceil_mode": False,
    }
    script_fun = build_script_function(target._schema, args, kwargs)
    print(script_fun.graph)


# CHECK: graph(%[[D2:.*]] : Tensor):
# CHECK: %[[D3:.*]] : int[] = prim::Constant[value=[3, 3]]()
# CHECK: %[[D4:.*]] : NoneType = prim::Constant()
# CHECK: %[[D5:.*]] : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %[[D6:.*]] : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %[[D7:.*]] : bool = prim::Constant[value=0]()
# CHECK: %[[D0:.*]] : Tensor, %[[D1:.*]] : Tensor = aten::max_pool2d_with_indices(%[[D2]], %[[D3]], %[[D4]], %[[D5]], %[[D6]], %[[D7]])
# CHECK: return (%[[D0]], %[[D1]])
# -----
# CHECK: PASS - handle_nones
@run_test
def handle_nones():
    target = torch.ops.aten.max_pool2d_with_indices.default
    # print(target._schema)
    args = (torch.randn((1, 3, 32, 32)),)
    kwargs = {
        "kernel_size": [3, 3],
        "stride": None,
        "padding": [0, 0],
        "dilation": [1, 1],
        "ceil_mode": False,
    }
    script_fun = build_script_function(target._schema, args, kwargs)
    print(script_fun.graph)


# CHECK: graph(%[[E1:.*]] : Tensor,
# CHECK: %[[E2:.*]] : Tensor,
# CHECK: %[[E3:.*]] : Tensor):
# CHECK: %[[E4:.*]] : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %[[E5:.*]] : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %[[E6:.*]] : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %[[E7:.*]] : bool = prim::Constant[value=0]()
# CHECK: %[[E8:.*]] : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %[[E9:.*]] : int = prim::Constant[value=1]()
# CHECK: %[[E0:.*]] : Tensor = aten::convolution(%[[E1]], %[[E2]], %[[E3]], %[[E4]], %[[E5]], %[[E6]], %[[E7]], %[[E8]], %[[E9]])
# CHECK: return (%[[E0]])
# -----
# CHECK: PASS - handle_optional_tensors
@run_test
def handle_optional_tensors():
    target = torch.ops.aten.convolution.default
    input = torch.randn(1, 3, 32, 32)
    weight = torch.randn(3, 3, 3, 3)
    bias = torch.randn(3)
    args = (input, weight, bias)
    kwargs = dict(
        stride=[1, 1],
        padding=[0, 0],
        dilation=[1, 1],
        transposed=False,
        output_padding=[0, 0],
        groups=1,
    )
    script_fun = build_script_function(target._schema, args, kwargs)
    print(script_fun.graph)


# CHECK: graph(%[[F1:.*]] : Tensor):
# CHECK: %[[F2:.*]] : NoneType = prim::Constant()
# CHECK: %[[F3:.*]] : NoneType = prim::Constant()
# CHECK: %[[F4:.*]] : NoneType = prim::Constant()
# CHECK: %[[F5:.*]] : NoneType = prim::Constant()
# CHECK: %[[F6:.*]] : NoneType = prim::Constant()
# CHECK: %[[F0:.*]] : Tensor = aten::ones_like(%[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]], %[[F6]])
# CHECK: return (%[[F0]])
# -----
# CHECK: PASS - handle_ones_like
@run_test
def handle_ones_like():
    target = torch.ops.aten.ones_like.default
    input = torch.randn(1, 3, 32, 32)
    args = (input,)
    kwargs = dict(
        dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
    )
    script_fun = build_script_function(target._schema, args, kwargs)
    print(script_fun.graph)


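# native_batch_norm returns three tensors, so the built graph should end in a
# three-value return; its positional float args (momentum, eps) become float
# constants.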
# CHECK: graph(%[[G3:.*]] : Tensor,
# CHECK: %[[G4:.*]] : Tensor,
# CHECK: %[[G5:.*]] : Tensor):
# CHECK: %[[G6:.*]] : NoneType = prim::Constant()
# CHECK: %[[G7:.*]] : NoneType = prim::Constant()
# CHECK: %[[G8:.*]] : bool = prim::Constant[value=0]()
# CHECK: %[[G9:.*]] : float = prim::Constant[value=1.]()
# CHECK: %[[G10:.*]] : float = prim::Constant[value=1.]()
# CHECK: %[[G0:.*]] : Tensor, %[[G1:.*]] : Tensor, %[[G2:.*]] : Tensor = aten::native_batch_norm(%[[G3]], %[[G4]], %[[G5]], %[[G6]], %[[G7]], %[[G8]], %[[G9]], %[[G10]])
# CHECK: return (%[[G0]], %[[G1]], %[[G2]])
# -----
# CHECK: PASS - handle_multiple_outputs
@run_test
def handle_multiple_outputs():
    target = torch.ops.aten.native_batch_norm.default
    A = torch.randn(1, 3, 32, 32)
    B = torch.randn(1, 3, 32, 32)
    C = torch.randn(1, 3, 32, 32)
    args = (A, B, C, None, None, False, 1.0, 1.0)
    kwargs = dict()
    script_fun = build_script_function(target._schema, args, kwargs)
    print(script_fun.graph)


# CHECK: f
# CHECK: PASS - check_legal_name
@run_test
def check_legal_name():
    target = torch.ops.aten.native_batch_norm.default
    A = torch.randn(1, 3, 32, 32)
    B = torch.randn(1, 3, 32, 32)
    C = torch.randn(1, 3, 32, 32)
    args = (A, B, C, None, None, False, 1.0, 1.0)
    kwargs = dict()
    script_fun = build_script_function(target._schema, args, kwargs)
    print(script_fun.name)


# CHECK: graph(%[[H3:.*]] : Tensor,
# CHECK: %[[H4:.*]] : Tensor,
# CHECK: %[[H5:.*]] : Tensor,
# CHECK: %[[H6:.*]] : Tensor,
# CHECK: %[[H7:.*]] : Tensor,
# CHECK: %out : Tensor,
# CHECK: %save_mean : Tensor,
# CHECK: %save_invstd : Tensor):
# CHECK: %[[H8:.*]] : bool = prim::Constant[value=0]()
# CHECK: %[[H9:.*]] : float = prim::Constant[value=0.10000000000000001]()
# CHECK: %[[H10:.*]] : float = prim::Constant[value=0.0001]()
# CHECK: %[[H0:.*]] : Tensor, %[[H1:.*]] : Tensor, %[[H2:.*]] : Tensor = aten::native_batch_norm(%[[H3]], %[[H4]], %[[H5]], %[[H6]], %[[H7]], %[[H8]], %[[H9]], %[[H10]], %out, %save_mean, %save_invstd)
# CHECK: return (%[[H0]], %[[H1]], %[[H2]])
# -----
# CHECK: PASS - correctly_order_kwargs
@run_test
def correctly_order_kwargs():
    target = torch.ops.aten.native_batch_norm.out

    input = torch.randn(2, 5, 2, 3)
    weight = torch.randn(5)
    bias = torch.randn(5)
    running_mean = torch.randn(5)
    running_var = torch.randn(5)
    args = (input, weight, bias, running_mean, running_var)

    out = torch.empty_like(input)
    save_mean = torch.empty_like(running_mean)
    save_invstd = torch.empty_like(running_var)
    kwargs = dict(
        training=False,
        momentum=0.1,
        eps=0.0001,
        out=out,
        save_mean=save_mean,
        save_invstd=save_invstd,
    )
    script_fun = build_script_function(target._schema, args, kwargs)
    print(script_fun.graph)