Test custom op import with symbolic shapes (#3431)

Tests the basic constructs of registering a custom op and its abstract
implementations (with FakeTensors) in python, going through TorchDynamo
export, followed by importing the shape expressions in the Torch
dialect.

Also fixes the importer were previously the symbolic bind op insertion
was not gated in one place.
pull/3441/head
Sambhav Jain 2024-06-09 00:32:49 -07:00 committed by GitHub
parent 5bc626465b
commit 7e0e23c668
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 88 additions and 1 deletions

View File

@ -1445,7 +1445,8 @@ class GraphNodeImporter:
operands = [self._import_argument(loc, arg) for arg in node.args[0]]
func_dialect.ReturnOp(operands, loc=loc)
self._create_bind_symbolic_shape_ops(loc, node)
if import_symbolic_shape_expressions:
self._create_bind_symbolic_shape_ops(loc, node)
def _promote_symbolic_scalar_int_float(self, loc, graph, param):
temp_target = torch.ops.aten.Float.Scalar

View File

@ -0,0 +1,86 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
import torch
import torch.nn as nn
from torch.export import Dim
from torch.library import Library, impl, impl_abstract
from torch_mlir import fx
def run(f):
print(f"{f.__name__}")
print("-" * len(f.__name__))
f()
print()
@run
# CHECK-LABEL: test_tanh_sigmoid_cat_custom_op
# CHECK: func.func @main(
# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int
# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int
# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: %[[OP:.+]] = torch.operator "torch.my_custom_library.tanh_sigmoid_cat_op"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[OP]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: return %[[OP]] : !torch.vtensor<[?,?,3],f32>
def test_tanh_sigmoid_cat_custom_op():
m = Library("my_custom_library", "DEF")
m.define("tanh_sigmoid_cat_op(Tensor x, Tensor y, Tensor z) -> Tensor")
@impl(m, "tanh_sigmoid_cat_op", "CompositeExplicitAutograd")
def custom_op(x, y, z):
a = torch.tanh(x)
b = torch.sigmoid(y)
return torch.cat((a, a, b, z), dim=1)
@impl_abstract("my_custom_library::tanh_sigmoid_cat_op")
def custom_op_meta(x, y, z):
result = custom_op(x, y, z)
return torch.empty_like(result)
class TanhSigmoidCatCustomOp(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y, z):
return torch.ops.my_custom_library.tanh_sigmoid_cat_op(x, y, z)
# Sample inputs
x = torch.randn(5, 2, 3)
y = torch.randn(5, 6, 3)
z = torch.randn(5, 4, 3)
# Dynamic dim constraints
dim_n = Dim("n", min=5, max=10)
dim_x1 = Dim("x1", max=100)
dim_y1 = Dim("y1", max=50)
dim_z1 = Dim("z1")
dynamic_shapes = {
"x": {0: dim_n, 1: dim_x1},
"y": {0: dim_n, 1: dim_y1},
"z": {0: dim_n, 1: dim_z1},
}
m = fx.export_and_import(
TanhSigmoidCatCustomOp(),
x,
y,
z,
dynamic_shapes=dynamic_shapes,
import_symbolic_shape_expressions=True,
)
print(m)