2022-05-30 18:04:24 +08:00
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
import torch
import torch_mlir
class TanhModule(torch.nn.Module):
    """Minimal module that applies the ATen `tanh` op to its single input."""

    def forward(self, x):
        # Invoke the ATen op directly; the CHECK lines in this file expect a
        # `torch.aten.tanh` op in the compiled output.
        result = torch.ops.aten.tanh(x)
        return result
# Static-shaped example input reused by all TanhModule cases in this file.
tanh_example_input = torch.ones(2, 3)

# Simplest case: the example tensor is passed directly as the only argument.
print(torch_mlir.compile(TanhModule(), tanh_example_input, use_tracing=True))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>

# Simplest case: the example tensor is wrapped in a one-element tuple.
print(torch_mlir.compile(TanhModule(), (tanh_example_input,), use_tracing=True))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>

# Simplest case: the example tensor is wrapped in a one-element list.
print(torch_mlir.compile(TanhModule(), [tanh_example_input], use_tracing=True))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
# TensorPlaceholder support: derive a placeholder from the example tensor,
# marking axis 1 as dynamic so it shows up as `?` in the lowered type.
placeholder = torch_mlir.TensorPlaceholder.like(
    tanh_example_input, dynamic_axes=[1])
print(torch_mlir.compile(TanhModule(), [placeholder],
                         use_tracing=True, ignore_traced_shapes=True))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32>

# Misuse checks: each call below is expected to raise. The message is
# printed so FileCheck can match it against the CHECK line.
try:
    # CHECK: `ignore_traced_shapes` requires `use_tracing`
    torch_mlir.compile(TanhModule(), [placeholder], ignore_traced_shapes=True)
except Exception as err:
    print(err)

try:
    # CHECK: TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`
    torch_mlir.compile(TanhModule(), [placeholder], use_tracing=True)
except Exception as err:
    print(err)
class DictModule(torch.nn.Module):
    """Module whose forward takes a dict input.

    It is only used to exercise the error path for dict-valued example
    args in the `torch_mlir.compile` checks of this file.
    """

    def forward(self, x):
        # The key is exactly 'a' (no surrounding spaces), matching the
        # `{'a': tensor(3.)}` text in the expected error message.
        return x['a'] * 2.0
# Dict-valued example args must be rejected, both when passed directly and
# when nested inside a list. Each error message is printed for FileCheck.
# The dict key must be exactly 'a' so the printed repr matches
# `{'a': tensor(3.)}` in the CHECK lines.
try:
    # CHECK: Only Tensors, TensorPlaceholder's, or sequences of Tensors and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
    torch_mlir.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
except Exception as e:
    print(e)

try:
    # CHECK: Only Tensors, TensorPlaceholder's, or sequences of Tensors and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
    torch_mlir.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
except Exception as e:
    print(e)