mirror of https://github.com/llvm/torch-mlir
Test custom op import with symbolic shapes (#3431)
Tests the basic constructs of registering a custom op and its abstract implementations (with FakeTensors) in Python, going through TorchDynamo export, followed by importing the shape expressions into the Torch dialect. Also fixes the importer, where the symbolic bind op insertion was previously not gated in one place.
pull/3441/head
parent 5bc626465b
commit 7e0e23c668
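In outline, the pattern the new test exercises (a condensed view of the test added below; the eager and FakeTensor kernel bodies are elided here):

    from torch.library import Library, impl, impl_abstract

    lib = Library("my_custom_library", "DEF")  # the test names this "m"
    lib.define("tanh_sigmoid_cat_op(Tensor x, Tensor y, Tensor z) -> Tensor")

    @impl(lib, "tanh_sigmoid_cat_op", "CompositeExplicitAutograd")
    def custom_op(x, y, z): ...  # eager implementation

    @impl_abstract("my_custom_library::tanh_sigmoid_cat_op")
    def custom_op_meta(x, y, z): ...  # FakeTensor implementation for export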
@@ -1445,7 +1445,8 @@ class GraphNodeImporter:
             operands = [self._import_argument(loc, arg) for arg in node.args[0]]
             func_dialect.ReturnOp(operands, loc=loc)

-            self._create_bind_symbolic_shape_ops(loc, node)
+            if import_symbolic_shape_expressions:
+                self._create_bind_symbolic_shape_ops(loc, node)

     def _promote_symbolic_scalar_int_float(self, loc, graph, param):
         temp_target = torch.ops.aten.Float.Scalar
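This gates `torch.bind_symbolic_shape` emission on the flag callers pass through `fx.export_and_import`. A minimal sketch of the two modes, with a hypothetical module and input for illustration:

    import torch
    import torch.nn as nn
    from torch_mlir import fx

    class MyModule(nn.Module):  # hypothetical example module
        def forward(self, x):
            return torch.tanh(x)

    x = torch.randn(3, 4)

    # Opt in: emit torch.symbolic_int / torch.bind_symbolic_shape ops for any
    # dynamic dimensions.
    m = fx.export_and_import(MyModule(), x, import_symbolic_shape_expressions=True)

    # Opt out: the now-gated _create_bind_symbolic_shape_ops call is skipped,
    # so no symbolic shape ops appear in the IR.
    m = fx.export_and_import(MyModule(), x)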
@@ -0,0 +1,86 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: %PYTHON %s | FileCheck %s

import torch
import torch.nn as nn
from torch.export import Dim
from torch.library import Library, impl, impl_abstract

from torch_mlir import fx


def run(f):
    print(f"{f.__name__}")
    print("-" * len(f.__name__))
    f()
    print()

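
# FileCheck patterns: each dynamic dimension becomes a torch.symbolic_int, and
# every argument and the result get a torch.bind_symbolic_shape tying them to
# those symbols.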
@run
# CHECK-LABEL: test_tanh_sigmoid_cat_custom_op
# CHECK: func.func @main(
# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int
# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int
# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: %[[OP:.+]] = torch.operator "torch.my_custom_library.tanh_sigmoid_cat_op"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[OP]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: return %[[OP]] : !torch.vtensor<[?,?,3],f32>
def test_tanh_sigmoid_cat_custom_op():
    m = Library("my_custom_library", "DEF")
    m.define("tanh_sigmoid_cat_op(Tensor x, Tensor y, Tensor z) -> Tensor")

    @impl(m, "tanh_sigmoid_cat_op", "CompositeExplicitAutograd")
    def custom_op(x, y, z):
        a = torch.tanh(x)
        b = torch.sigmoid(y)
        return torch.cat((a, a, b, z), dim=1)
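
    # Register the FakeTensor (abstract) implementation; torch.export calls it
    # during tracing to infer the custom op's output shape and dtype.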
    @impl_abstract("my_custom_library::tanh_sigmoid_cat_op")
    def custom_op_meta(x, y, z):
        result = custom_op(x, y, z)
        return torch.empty_like(result)

    class TanhSigmoidCatCustomOp(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y, z):
            return torch.ops.my_custom_library.tanh_sigmoid_cat_op(x, y, z)

    # Sample inputs
    x = torch.randn(5, 2, 3)
    y = torch.randn(5, 6, 3)
    z = torch.randn(5, 4, 3)

    # Dynamic dim constraints
    dim_n = Dim("n", min=5, max=10)
    dim_x1 = Dim("x1", max=100)
    dim_y1 = Dim("y1", max=50)
    dim_z1 = Dim("z1")
    dynamic_shapes = {
        "x": {0: dim_n, 1: dim_x1},
        "y": {0: dim_n, 1: dim_y1},
        "z": {0: dim_n, 1: dim_z1},
    }
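
    # Export through TorchDynamo and import into the Torch dialect, opting in
    # to the symbolic shape expressions this test checks for.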
    m = fx.export_and_import(
        TanhSigmoidCatCustomOp(),
        x,
        y,
        z,
        dynamic_shapes=dynamic_shapes,
        import_symbolic_shape_expressions=True,
    )
    print(m)
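Note on the result binding checked above: `custom_op` concatenates `(a, a, b, z)` along dim 1, with sizes `x1`, `x1`, `y1`, `z1`, so the output's second dimension is `2*x1 + y1 + z1`; with those dims exported as `s1`, `s2`, and `s3`, that is the `s2 + s3 + s1 * 2` expression in the final `torch.bind_symbolic_shape` CHECK line.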