mirror of https://github.com/llvm/torch-mlir
44 lines
1.1 KiB
Python
44 lines
1.1 KiB
Python
|
# -*- Python -*-
|
||
|
# This file is licensed under a pytorch-style license
|
||
|
# See frontends/pytorch/LICENSE for license information.
|
||
|
|
||
|
import torch
|
||
|
import npcomp.frontends.pytorch as torch_mlir
|
||
|
import npcomp.frontends.pytorch.test as test
|
||
|
|
||
|
# RUN: python %s | FileCheck %s
|
||
|
|
||
|
# Device handle under test; every tensor created with `device=dev` is routed
# through the torch_mlir frontend.
dev = torch_mlir.mlir_device()

# Print the stride of an elementwise product computed on that device.
x = torch.rand((3, 64, 8, 8), device=dev)
y = x * x
print(y.stride())

# Geometry of the strided view: N slices, each of shape `dim`.
dim = [4, 4, 4]
N = 2
count = dim[0] * dim[1] * dim[2]
sizes = (N, dim[0], dim[1], dim[2])
# NOTE(review): the leading stride is 1 rather than `count`, so consecutive
# slices along the first axis overlap in storage -- presumably intentional
# for this test; confirm.
strides = (1, dim[1] * dim[2], dim[2], 1)
print(count)

# `t0` is created on the device under test. `t0_like` has the same shape but
# is created without `device=`; it is only used to obtain the expected stride.
t0 = torch.randn((N, count), device=dev)
t0_like = torch.randn((N, count))

# Same view taken three ways: on the device, on a CPU copy of the device
# data (cloned so the reference is its own tensor), and on the look-alike.
t1 = t0.as_strided(sizes, strides)
t1_ref = t0.to('cpu').as_strided(sizes, strides).clone()
t1_like = t0_like.as_strided(sizes, strides)

# check that the IR has recorded the
# stride properly before invoking JIT
# CHECK: PASS! stride check
test.compare_eq(t1.stride(), t1_like.stride(), "stride")

# CHECK: PASS! as_stride check
test.compare(t1_ref, t1, "as_stride")

# CHECK: PASS! as_stride stride check
test.compare_eq(t1_ref.stride(), t1.to("cpu").stride(), "as_stride stride")
|