# Mirror of https://github.com/llvm/torch-mlir (Python, 43 lines, 1.7 KiB)
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

|
import numpy as np
|
||
|
import npcomp as npc
|
||
|
from npcomp.types import *
|
||
|
|
||
|
|
||
|
def transpose_attribute(a: np.ndarray) -> np.ndarray:
    """Return ``a`` with its axes reversed, obtained via the ``.T`` attribute.

    The attribute form (rather than calling ``np.transpose``, as the sibling
    ``transpose`` function does) is the only thing that distinguishes this
    trace case, so it must stay an attribute access.
    """
    flipped = a.T
    return flipped
|
||
|
|
||
|
|
||
|
def transpose(a: np.ndarray) -> np.ndarray:
    """Return ``a`` with its axes reversed, obtained via ``np.transpose``.

    Counterpart of ``transpose_attribute``: same result, but reached through
    the module-level function call instead of the ``.T`` attribute.
    """
    result = np.transpose(a)
    return result
|
||
|
|
||
|
|
||
|
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
# Register each function on the exporter under its own name. setattr()
# performs the same attribute assignment, in the same order, as writing
# `exp.<name> = <func>` by hand.
for _fn in (transpose_attribute, transpose):
    setattr(exp, _fn.__name__, _fn)
del _fn  # don't leave the loop variable behind as a module global

mb = npc.tracing.ModuleBuilder()
# Trace both exported functions into the builder's single module; the
# CHECK directives expect one func per traced function.
mb.trace(exp.transpose_attribute, exp.transpose)
|
||
|
|
||
|
# TODO: Consolidate any_dtype -> UnknownType.
# CHECK-LABEL: func @transpose_attribute(
# CHECK-SAME: %[[VAL_0:.*]]: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
# CHECK: %[[VAL_1:.*]] = numpy.transpose %[[VAL_0]] : (tensor<*x!numpy.any_dtype>) -> tensor<*x!basicpy.UnknownType>
# CHECK: %[[VAL_2:.*]] = numpy.narrow %[[VAL_1]] : (tensor<*x!basicpy.UnknownType>) -> tensor<*x!numpy.any_dtype>
# CHECK: return %[[VAL_2]] : tensor<*x!numpy.any_dtype>
# CHECK: }

# CHECK-LABEL: func @transpose(
# CHECK-SAME: %[[VAL_0:.*]]: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
# CHECK: %[[VAL_1:.*]] = numpy.transpose %[[VAL_0]] : (tensor<*x!numpy.any_dtype>) -> tensor<*x!basicpy.UnknownType>
# CHECK: %[[VAL_2:.*]] = numpy.narrow %[[VAL_1]] : (tensor<*x!basicpy.UnknownType>) -> tensor<*x!numpy.any_dtype>
# CHECK: return %[[VAL_2]] : tensor<*x!numpy.any_dtype>
# CHECK: }
|
||
|
# The RUN line pipes this script's stdout into FileCheck, which matches the
# printed module text against the CHECK directives in this file.
print(str(mb.module))
|