[Torch] Add support for Meshgrid (#3462)

pull/3461/merge
Xinyu Yang 2024-06-14 23:59:08 +08:00 committed by GitHub
parent a02e14e971
commit 6f94c7b0aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 235 additions and 0 deletions

View File

@ -13733,6 +13733,54 @@ def Torch_AtenChunkOp : Torch_Op<"aten.chunk", [
}];
}
def Torch_AtenMeshgridOp : Torch_Op<"aten.meshgrid", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::meshgrid : (Tensor[]) -> (Tensor[])`";
let arguments = (ins
AnyTorchListOfTensorType:$tensors
);
let results = (outs
AnyTorchListOfTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMeshgridOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenMeshgridOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_AtenMeshgridIndexingOp : Torch_Op<"aten.meshgrid.indexing", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::meshgrid.indexing : (Tensor[], str) -> (Tensor[])`";
let arguments = (ins
AnyTorchListOfTensorType:$tensors,
Torch_StringType:$indexing
);
let results = (outs
AnyTorchListOfTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMeshgridIndexingOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenMeshgridIndexingOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -3039,6 +3039,20 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
});
}
//===----------------------------------------------------------------------===//
// AtenMeshgridOp
//===----------------------------------------------------------------------===//
void AtenMeshgridOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenMeshgridOp op, PatternRewriter &rewriter) {
Value constIndexing = rewriter.create<Torch::ConstantStrOp>(
op->getLoc(), rewriter.getStringAttr("ij"));
rewriter.replaceOpWithNewOp<AtenMeshgridIndexingOp>(
op, op->getResultTypes(), op.getTensors(), constIndexing);
return success();
});
}
//===----------------------------------------------------------------------===//
// AtenSplitSizesOp
//===----------------------------------------------------------------------===//

View File

@ -719,6 +719,81 @@ public:
};
} // namespace
namespace {
class RecomposeMeshgridIndexingListUnpack
: public OpRewritePattern<PrimListUnpackOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimListUnpackOp op,
PatternRewriter &rewriter) const override {
auto meshgridIndexingOp =
op.getOperand().getDefiningOp<Torch::AtenMeshgridIndexingOp>();
if (!meshgridIndexingOp)
return rewriter.notifyMatchFailure(op,
"Input is not AtenMeshgridIndexingOp");
Location loc = meshgridIndexingOp.getLoc();
auto context = meshgridIndexingOp.getContext();
auto baseType = NonValueTensorType::getWithLeastStaticInformation(context);
SmallVector<Value> tensors;
if (!getListConstructElements(meshgridIndexingOp.getTensors(), tensors))
return rewriter.notifyMatchFailure(meshgridIndexingOp,
"Unable to get tensors");
int64_t numTensors = tensors.size();
bool swapFirstAndSecondTensors = false;
std::string indexing;
if (!matchPattern(meshgridIndexingOp.getIndexing(),
m_TorchConstantStr(indexing))) {
return rewriter.notifyMatchFailure(meshgridIndexingOp,
"Unable to get indexing");
}
if (indexing == "xy" && numTensors >= 2) {
swapFirstAndSecondTensors = true;
std::swap(tensors[0], tensors[1]);
}
SmallVector<Value> expandShapeValues;
for (int64_t i = 0; i < numTensors; i++) {
expandShapeValues.push_back(
rewriter.create<AtenNumelOp>(loc, tensors[i]));
}
Value expandShapeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), expandShapeValues);
SmallVector<Value> meshgrids;
Value constFalse =
rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(false));
for (auto [idx, tensor] : llvm::enumerate(tensors)) {
Value constantOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
SmallVector<Value> tensorViewShapeValues(numTensors, constantOne);
tensorViewShapeValues[idx] = expandShapeValues[idx];
Value viewShapeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), tensorViewShapeValues);
Value view =
rewriter.create<AtenViewOp>(loc, baseType, tensor, viewShapeList);
Value expandView = rewriter.create<AtenExpandOp>(
loc, baseType, view, expandShapeList, constFalse);
meshgrids.push_back(expandView);
}
if (swapFirstAndSecondTensors) {
std::swap(meshgrids[0], meshgrids[1]);
}
rewriter.replaceOp(op, meshgrids);
// erase meshgridIndexingOp if no user left
if (meshgridIndexingOp.getResult().use_empty())
rewriter.eraseOp(meshgridIndexingOp);
return success();
}
};
} // namespace
namespace {
class RecomposeComplexOpsPass
: public RecomposeComplexOpsBase<RecomposeComplexOpsPass> {
@ -742,6 +817,7 @@ public:
patterns.add<RecomposeUnbindListUnpack>(context);
patterns.add<RecomposeUnbindGetItem>(context);
patterns.add<RecomposeChunkListUnpack>(context);
patterns.add<RecomposeMeshgridIndexingListUnpack>(context);
GreedyRewriteConfig config;
config.useTopDownTraversal = true;

View File

@ -821,6 +821,9 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
}
STABLEHLO_PASS_SET = {
"MeshgridIndexingIJ_basic",
"MeshgridIndexingXY_basic",
"Meshgrid_basic",
"SplitWithSizes_Module_basic",
"TensorSplitSections_GetItemModule_basic",
"TensorSplitSections_ListUnpackModule_basic",
@ -1477,6 +1480,9 @@ STABLEHLO_CRASHING_SET = set()
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"MeshgridIndexingIJ_basic",
"MeshgridIndexingXY_basic",
"Meshgrid_basic",
"AvgPool2dCountIncludePadFalseStaticModule_basic",
"TensorSplitSections_GetItemModule_basic",
"TensorSplitSections_ListUnpackModule_basic",

View File

@ -979,6 +979,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::tensor_split.sections : (Tensor, int, int) -> (Tensor[])")
emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
emit("aten::chunk : (Tensor, int, int) -> (Tensor[])")
emit("aten::meshgrid : (Tensor[]) -> (Tensor[])", has_canonicalizer=True)
emit("aten::meshgrid.indexing : (Tensor[], str) -> (Tensor[])")
# Str ops.
emit("aten::add.str : (str, str) -> (str)")

View File

@ -57,3 +57,4 @@ def register_all_tests():
from . import padding
from . import diagonal
from . import gridsampler
from . import meshgrid

View File

@ -0,0 +1,88 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class MeshgridIndexingIJ(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([3], torch.int64, True),
([4], torch.int64, True),
([5], torch.int64, True),
]
)
def forward(self, x, y, z):
x1, y1, z1 = torch.meshgrid(x, y, z, indexing="ij")
return x1, y1, z1
@register_test_case(module_factory=lambda: MeshgridIndexingIJ())
def MeshgridIndexingIJ_basic(module, tu: TestUtils):
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6, 7])
z = torch.tensor([8, 9, 10, 11, 12])
module.forward(x, y, z)
class MeshgridIndexingXY(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([3], torch.int64, True),
([4], torch.int64, True),
([5], torch.int64, True),
]
)
def forward(self, x, y, z):
x1, y1, z1 = torch.meshgrid(x, y, z, indexing="xy")
return x1, y1, z1
@register_test_case(module_factory=lambda: MeshgridIndexingXY())
def MeshgridIndexingXY_basic(module, tu: TestUtils):
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6, 7])
z = torch.tensor([8, 9, 10, 11, 12])
module.forward(x, y, z)
class Meshgrid(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([3], torch.int64, True),
([4], torch.int64, True),
]
)
def forward(self, x, y):
x1, y1 = torch.meshgrid(x, y)
return x1, y1
@register_test_case(module_factory=lambda: Meshgrid())
def Meshgrid_basic(module, tu: TestUtils):
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6, 7])
module.forward(x, y)