From 6f94c7b0aadeee0138f928d46cdc96d2f7b42023 Mon Sep 17 00:00:00 2001
From: Xinyu Yang
Date: Fri, 14 Jun 2024 23:59:08 +0800
Subject: [PATCH] [Torch] Add support for Meshgrid (#3462)

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 48 ++++++++++
 lib/Dialect/Torch/IR/TorchOps.cpp             | 14 +++
 .../Torch/Transforms/RecomposeComplexOps.cpp  | 76 ++++++++++++++++
 projects/pt1/e2e_testing/xfail_sets.py        |  6 ++
 .../build_tools/torch_ods_gen.py              |  2 +
 .../test_suite/__init__.py                    |  1 +
 .../test_suite/meshgrid.py                    | 88 +++++++++++++++++++
 7 files changed, 235 insertions(+)
 create mode 100644 projects/pt1/python/torch_mlir_e2e_test/test_suite/meshgrid.py

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index c22f46ebe..5af6873d8 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -13733,6 +13733,54 @@ def Torch_AtenChunkOp : Torch_Op<"aten.chunk", [
   }];
 }
 
+def Torch_AtenMeshgridOp : Torch_Op<"aten.meshgrid", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::meshgrid : (Tensor[]) -> (Tensor[])`";
+  let arguments = (ins
+    AnyTorchListOfTensorType:$tensors
+  );
+  let results = (outs
+    AnyTorchListOfTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMeshgridOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenMeshgridOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+  let hasCanonicalizer = 1;
+}
+
+def Torch_AtenMeshgridIndexingOp : Torch_Op<"aten.meshgrid.indexing", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::meshgrid.indexing : (Tensor[], str) -> (Tensor[])`";
+  let arguments = (ins
+    AnyTorchListOfTensorType:$tensors,
+    Torch_StringType:$indexing
+  );
+  let results = (outs
+    AnyTorchListOfTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMeshgridIndexingOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenMeshgridIndexingOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 61a0857a8..140549ed5 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -3039,6 +3039,20 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenMeshgridOp
+//===----------------------------------------------------------------------===//
+void AtenMeshgridOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add(+[](AtenMeshgridOp op, PatternRewriter &rewriter) {
+    Value constIndexing = rewriter.create<Torch::ConstantStrOp>(
+        op->getLoc(), rewriter.getStringAttr("ij"));
+    rewriter.replaceOpWithNewOp<AtenMeshgridIndexingOp>(
+        op, op->getResultTypes(), op.getTensors(), constIndexing);
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenSplitSizesOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
index b930778ff..d9b2648f6 100644
--- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
@@ -719,6 +719,81 @@ public:
 };
 } // namespace
 
+namespace {
+class RecomposeMeshgridIndexingListUnpack
+    : public OpRewritePattern<PrimListUnpackOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrimListUnpackOp op,
+                                PatternRewriter &rewriter) const override {
+    auto meshgridIndexingOp =
+        op.getOperand().getDefiningOp<AtenMeshgridIndexingOp>();
+    if (!meshgridIndexingOp)
+      return rewriter.notifyMatchFailure(op,
+                                         "Input is not AtenMeshgridIndexingOp");
+    Location loc = meshgridIndexingOp.getLoc();
+    auto context = meshgridIndexingOp.getContext();
+    auto baseType = NonValueTensorType::getWithLeastStaticInformation(context);
+    SmallVector<Value> tensors;
+    if (!getListConstructElements(meshgridIndexingOp.getTensors(), tensors))
+      return rewriter.notifyMatchFailure(meshgridIndexingOp,
+                                         "Unable to get tensors");
+
+    int64_t numTensors = tensors.size();
+    bool swapFirstAndSecondTensors = false;
+
+    std::string indexing;
+    if (!matchPattern(meshgridIndexingOp.getIndexing(),
+                      m_TorchConstantStr(indexing))) {
+      return rewriter.notifyMatchFailure(meshgridIndexingOp,
+                                         "Unable to get indexing");
+    }
+
+    if (indexing == "xy" && numTensors >= 2) {
+      swapFirstAndSecondTensors = true;
+      std::swap(tensors[0], tensors[1]);
+    }
+
+    SmallVector<Value> expandShapeValues;
+    for (int64_t i = 0; i < numTensors; i++) {
+      expandShapeValues.push_back(
+          rewriter.create<AtenNumelOp>(loc, tensors[i]));
+    }
+    Value expandShapeList = rewriter.create<PrimListConstructOp>(
+        loc, ListType::get(IntType::get(context)), expandShapeValues);
+
+    SmallVector<Value> meshgrids;
+    Value constFalse =
+        rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(false));
+
+    for (auto [idx, tensor] : llvm::enumerate(tensors)) {
+      Value constantOne =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+      SmallVector<Value> tensorViewShapeValues(numTensors, constantOne);
+      tensorViewShapeValues[idx] = expandShapeValues[idx];
+
+      Value viewShapeList = rewriter.create<PrimListConstructOp>(
+          loc, ListType::get(IntType::get(context)), tensorViewShapeValues);
+      Value view =
+          rewriter.create<AtenViewOp>(loc, baseType, tensor, viewShapeList);
+
+      Value expandView = rewriter.create<AtenExpandOp>(
+          loc, baseType, view, expandShapeList, constFalse);
+      meshgrids.push_back(expandView);
+    }
+
+    if (swapFirstAndSecondTensors) {
+      std::swap(meshgrids[0], meshgrids[1]);
+    }
+    rewriter.replaceOp(op, meshgrids);
+    // erase meshgridIndexingOp if no user left
+    if (meshgridIndexingOp.getResult().use_empty())
+      rewriter.eraseOp(meshgridIndexingOp);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class RecomposeComplexOpsPass
     : public RecomposeComplexOpsBase<RecomposeComplexOpsPass> {
@@ -742,6 +817,7 @@ public:
     patterns.add<RecomposeUnbindListUnpack>(context);
     patterns.add<RecomposeUnbindGetItem>(context);
     patterns.add<RecomposeChunkListUnpack>(context);
+    patterns.add<RecomposeMeshgridIndexingListUnpack>(context);
 
     GreedyRewriteConfig config;
     config.useTopDownTraversal = true;
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index be9498a53..7eb3d5e4e 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -821,6 +821,9 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
 }
 
 STABLEHLO_PASS_SET = {
+    "MeshgridIndexingIJ_basic",
+    "MeshgridIndexingXY_basic",
+    "Meshgrid_basic",
     "SplitWithSizes_Module_basic",
     "TensorSplitSections_GetItemModule_basic",
     "TensorSplitSections_ListUnpackModule_basic",
@@ -1477,6 +1480,9 @@ STABLEHLO_CRASHING_SET = set()
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "MeshgridIndexingIJ_basic",
+    "MeshgridIndexingXY_basic",
+    "Meshgrid_basic",
     "AvgPool2dCountIncludePadFalseStaticModule_basic",
     "TensorSplitSections_GetItemModule_basic",
     "TensorSplitSections_ListUnpackModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 106fa18ae..17c706f25 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -979,6 +979,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::tensor_split.sections : (Tensor, int, int) -> (Tensor[])")
     emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
     emit("aten::chunk : (Tensor, int, int) -> (Tensor[])")
+    emit("aten::meshgrid : (Tensor[]) -> (Tensor[])", has_canonicalizer=True)
+    emit("aten::meshgrid.indexing : (Tensor[], str) -> (Tensor[])")
 
     # Str ops.
     emit("aten::add.str : (str, str) -> (str)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
index dca86870f..46d2909eb 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -57,3 +57,4 @@ def register_all_tests():
     from . import padding
     from . import diagonal
     from . import gridsampler
+    from . import meshgrid
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/meshgrid.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/meshgrid.py
new file mode 100644
index 000000000..5cbd50473
--- /dev/null
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/meshgrid.py
@@ -0,0 +1,88 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
+
+# ==============================================================================
+
+
+class MeshgridIndexingIJ(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3], torch.int64, True),
+            ([4], torch.int64, True),
+            ([5], torch.int64, True),
+        ]
+    )
+    def forward(self, x, y, z):
+        x1, y1, z1 = torch.meshgrid(x, y, z, indexing="ij")
+        return x1, y1, z1
+
+
+@register_test_case(module_factory=lambda: MeshgridIndexingIJ())
+def MeshgridIndexingIJ_basic(module, tu: TestUtils):
+    x = torch.tensor([1, 2, 3])
+    y = torch.tensor([4, 5, 6, 7])
+    z = torch.tensor([8, 9, 10, 11, 12])
+    module.forward(x, y, z)
+
+
+class MeshgridIndexingXY(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3], torch.int64, True),
+            ([4], torch.int64, True),
+            ([5], torch.int64, True),
+        ]
+    )
+    def forward(self, x, y, z):
+        x1, y1, z1 = torch.meshgrid(x, y, z, indexing="xy")
+        return x1, y1, z1
+
+
+@register_test_case(module_factory=lambda: MeshgridIndexingXY())
+def MeshgridIndexingXY_basic(module, tu: TestUtils):
+    x = torch.tensor([1, 2, 3])
+    y = torch.tensor([4, 5, 6, 7])
+    z = torch.tensor([8, 9, 10, 11, 12])
+    module.forward(x, y, z)
+
+
+class Meshgrid(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3], torch.int64, True),
+            ([4], torch.int64, True),
+        ]
+    )
+    def forward(self, x, y):
+        x1, y1 = torch.meshgrid(x, y)
+        return x1, y1
+
+
+@register_test_case(module_factory=lambda: Meshgrid())
+def Meshgrid_basic(module, tu: TestUtils):
+    x = torch.tensor([1, 2, 3])
+    y = torch.tensor([4, 5, 6, 7])
+    module.forward(x, y)