2020-09-26 09:13:16 +08:00
|
|
|
# -*- Python -*-
|
|
|
|
# This file is licensed under a pytorch-style license
|
|
|
|
# See frontends/pytorch/LICENSE for license information.
|
2020-10-10 07:21:01 +08:00
|
|
|
# RUN: %PYTHON %s | FileCheck %s
|
2020-09-26 09:13:16 +08:00
|
|
|
|
2020-10-02 09:59:58 +08:00
|
|
|
# TODO: Once stabilized, expand tests to include all argument dtypes.
|
|
|
|
|
2020-09-26 09:13:16 +08:00
|
|
|
import torch
|
2020-10-13 12:39:48 +08:00
|
|
|
import torch_mlir
|
2020-09-26 09:13:16 +08:00
|
|
|
|
2020-10-06 14:21:21 +08:00
|
|
|
# Two f32 inputs whose shapes broadcast under addition: (1, 4) + (4, 1)
# gives (4, 4). The expected IR in this file mentions only shapes and
# dtypes, never element values, so the random contents are irrelevant.
t0 = torch.randn(1, 4)
t1 = torch.randn(4, 1)

# Trace the ops executed inside the `with` body into a function named
# "foobar" whose arguments are, in order, t0 and t1 (they show up as
# %arg0 / %arg1 in the printed module). `f` stays bound after the block so
# its debug log can be dumped at the end of the script.
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("foobar", [t0, t1]) as f:
  result = t0 + t1
  # Register the broadcast sum as the captured function's return value.
  f.returns([result])
|
2020-09-26 09:13:16 +08:00
|
|
|
|
2020-09-29 09:36:00 +08:00
|
|
|
# Expected textual form of `mb.module`. FileCheck (driven by the RUN line
# at the top of this file) matches these directives, in order, against what
# the print below emits. The module should contain a single function
# `foobar`:
#   - it takes the (1,4) and (4,1) f32 inputs;
#   - it returns their broadcast (4,4) sum via an "aten::add" kernel call.
# The extra i64 constant-1 operand is presumably aten::add's `alpha` scalar
# (default 1). NOTE(review): confirm against the op schema.
# CHECK: module {
# CHECK: func @foobar(%arg0: !numpy.ndarray<[1,4]:f32>, %arg1: !numpy.ndarray<[4,1]:f32>) -> !numpy.ndarray<[4,4]:f32> {
# CHECK: %c1_i64 = constant 1 : i64
# CHECK: %0 = torch.kernel_call "aten::add" %arg0, %arg1, %c1_i64 : (!numpy.ndarray<[1,4]:f32>, !numpy.ndarray<[4,1]:f32>, i64) -> !numpy.ndarray<[4,4]:f32>
# CHECK: return %0 : !numpy.ndarray<[4,4]:f32>
# CHECK: }
# CHECK: }
print(mb.module)
|
2020-09-26 09:13:16 +08:00
|
|
|
|
|
|
|
# CHECK: CAPTURE: aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)
# Emit every entry of the capture's debug log on its own output line. This
# happens after the module print, so the directive above must match later
# in the output than the module's closing brace.
for line in f.get_debug_log():
  print(line)
|