mirror of https://github.com/llvm/torch-mlir
87 lines
3.4 KiB
Python
87 lines
3.4 KiB
Python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.export import Dim
|
|
from torch.library import Library, impl, impl_abstract
|
|
|
|
from torch_mlir import fx
|
|
|
|
|
|
def run(f):
    """Call *f* immediately, framed by a banner built from its name.

    Intended to be applied as a bare decorator, so each test body executes
    at import time. The output is:

    * the function's ``__name__`` on its own line,
    * a row of dashes of the same length,
    * whatever *f* itself prints,
    * a single blank separator line.

    Nothing is returned, so the decorated module-level name ends up bound
    to ``None``.
    """
    title = f.__name__
    underline = "-" * len(title)
    print(f"{title}", underline, sep="\n")
    f()
    print()
|
|
|
|
|
|
@run
# CHECK-LABEL: test_tanh_sigmoid_cat_custom_op
# CHECK: func.func @main(
# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int
# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int
# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: %[[OP:.+]] = torch.operator "torch.my_custom_library.tanh_sigmoid_cat_op"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[OP]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: return %[[OP]] : !torch.vtensor<[?,?,3],f32>
def test_tanh_sigmoid_cat_custom_op():
    """Import a custom op with dynamic dims, including symbolic shape exprs.

    Defines ``my_custom_library::tanh_sigmoid_cat_op``, which concatenates
    tanh(x) twice, sigmoid(y) and z along dim 1. It then exports a module
    that calls the op, with dims 0 and 1 of every input marked dynamic, and
    prints the imported MLIR. The FileCheck directives above verify that
    output.
    """
    # The op definition and its CompositeExplicitAutograd kernel are
    # registered through this object. torch.library ties registration
    # lifetime to the Library, so keep it under its own name rather than
    # rebinding that name before the test finishes.
    lib = Library("my_custom_library", "DEF")
    lib.define("tanh_sigmoid_cat_op(Tensor x, Tensor y, Tensor z) -> Tensor")

    @impl(lib, "tanh_sigmoid_cat_op", "CompositeExplicitAutograd")
    def custom_op(x, y, z):
        a = torch.tanh(x)
        b = torch.sigmoid(y)
        return torch.cat((a, a, b, z), dim=1)

    # Abstract (fake) kernel. Reusing the eager composite derives the
    # output's dim-1 size from the inputs (x1 * 2 + y1 + z1), which is the
    # expression the importer must reproduce for the op result.
    @impl_abstract("my_custom_library::tanh_sigmoid_cat_op")
    def custom_op_meta(x, y, z):
        result = custom_op(x, y, z)
        return torch.empty_like(result)

    class TanhSigmoidCatCustomOp(nn.Module):
        def forward(self, x, y, z):
            return torch.ops.my_custom_library.tanh_sigmoid_cat_op(x, y, z)

    # Sample inputs. Dim 2 is not listed in dynamic_shapes, so it stays
    # static (3).
    x = torch.randn(5, 2, 3)
    y = torch.randn(5, 6, 3)
    z = torch.randn(5, 4, 3)

    # Dynamic dim constraints, keyed by forward()'s parameter names.
    # Dim 0 shares one symbol across all inputs; each input gets its own
    # symbol for dim 1.
    dim_n = Dim("n", min=5, max=10)
    dim_x1 = Dim("x1", max=100)
    dim_y1 = Dim("y1", max=50)
    dim_z1 = Dim("z1")
    dynamic_shapes = {
        "x": {0: dim_n, 1: dim_x1},
        "y": {0: dim_n, 1: dim_y1},
        "z": {0: dim_n, 1: dim_z1},
    }

    module = fx.export_and_import(
        TanhSigmoidCatCustomOp(),
        x,
        y,
        z,
        dynamic_shapes=dynamic_shapes,
        import_symbolic_shape_expressions=True,
    )
    print(module)
|