From 7e0e23c66820d1db548103acbdf1337f701dc5a3 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Sun, 9 Jun 2024 00:32:49 -0700 Subject: [PATCH] Test custom op import with symbolic shapes (#3431) Tests the basic constructs of registering a custom op and its abstract implementations (with FakeTensors) in python, going through TorchDynamo export, followed by importing the shape expressions in the Torch dialect. Also fixes the importer were previously the symbolic bind op insertion was not gated in one place. --- python/torch_mlir/extras/fx_importer.py | 3 +- test/python/fx_importer/custom_op_test.py | 86 +++++++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 test/python/fx_importer/custom_op_test.py diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 16c27c0fa..2a73325c7 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1445,7 +1445,8 @@ class GraphNodeImporter: operands = [self._import_argument(loc, arg) for arg in node.args[0]] func_dialect.ReturnOp(operands, loc=loc) - self._create_bind_symbolic_shape_ops(loc, node) + if import_symbolic_shape_expressions: + self._create_bind_symbolic_shape_ops(loc, node) def _promote_symbolic_scalar_int_float(self, loc, graph, param): temp_target = torch.ops.aten.Float.Scalar diff --git a/test/python/fx_importer/custom_op_test.py b/test/python/fx_importer/custom_op_test.py new file mode 100644 index 000000000..dbbc5ba05 --- /dev/null +++ b/test/python/fx_importer/custom_op_test.py @@ -0,0 +1,86 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s + +import torch +import torch.nn as nn +from torch.export import Dim +from torch.library import Library, impl, impl_abstract + +from torch_mlir import fx + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_tanh_sigmoid_cat_custom_op +# CHECK: func.func @main( +# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, +# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, +# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int +# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int +# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int +# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: %[[OP:.+]] = torch.operator "torch.my_custom_library.tanh_sigmoid_cat_op"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[OP]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: return %[[OP]] : !torch.vtensor<[?,?,3],f32> +def test_tanh_sigmoid_cat_custom_op(): + + m = Library("my_custom_library", "DEF") + m.define("tanh_sigmoid_cat_op(Tensor x, Tensor y, Tensor z) -> Tensor") + + @impl(m, "tanh_sigmoid_cat_op", "CompositeExplicitAutograd") + def custom_op(x, y, z): + a = torch.tanh(x) + b = torch.sigmoid(y) + return torch.cat((a, a, b, z), dim=1) + + @impl_abstract("my_custom_library::tanh_sigmoid_cat_op") + def custom_op_meta(x, y, z): + result = custom_op(x, y, z) + return torch.empty_like(result) + + class TanhSigmoidCatCustomOp(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y, z): + return torch.ops.my_custom_library.tanh_sigmoid_cat_op(x, y, z) + + # Sample inputs + x = torch.randn(5, 2, 3) + y = torch.randn(5, 6, 3) + z = torch.randn(5, 4, 3) + + # Dynamic dim constraints + dim_n = Dim("n", min=5, max=10) + dim_x1 = Dim("x1", max=100) + dim_y1 = Dim("y1", max=50) + dim_z1 = Dim("z1") + dynamic_shapes = { + "x": {0: dim_n, 1: dim_x1}, + "y": {0: dim_n, 1: dim_y1}, + "z": {0: dim_n, 1: dim_z1}, + } + + m = fx.export_and_import( + TanhSigmoidCatCustomOp(), + x, + y, + z, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m)