mirror of https://github.com/llvm/torch-mlir
[Torch] Add support for Meshgrid (#3462)
parent
a02e14e971
commit
6f94c7b0aa
|
@ -13733,6 +13733,54 @@ def Torch_AtenChunkOp : Torch_Op<"aten.chunk", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenMeshgridOp : Torch_Op<"aten.meshgrid", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::meshgrid : (Tensor[]) -> (Tensor[])`";
|
||||
let arguments = (ins
|
||||
AnyTorchListOfTensorType:$tensors
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchListOfTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenMeshgridOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenMeshgridOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenMeshgridIndexingOp : Torch_Op<"aten.meshgrid.indexing", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::meshgrid.indexing : (Tensor[], str) -> (Tensor[])`";
|
||||
let arguments = (ins
|
||||
AnyTorchListOfTensorType:$tensors,
|
||||
Torch_StringType:$indexing
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchListOfTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenMeshgridIndexingOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenMeshgridIndexingOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -3039,6 +3039,20 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenMeshgridOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
void AtenMeshgridOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add(+[](AtenMeshgridOp op, PatternRewriter &rewriter) {
|
||||
Value constIndexing = rewriter.create<Torch::ConstantStrOp>(
|
||||
op->getLoc(), rewriter.getStringAttr("ij"));
|
||||
rewriter.replaceOpWithNewOp<AtenMeshgridIndexingOp>(
|
||||
op, op->getResultTypes(), op.getTensors(), constIndexing);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSplitSizesOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -719,6 +719,81 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class RecomposeMeshgridIndexingListUnpack
|
||||
: public OpRewritePattern<PrimListUnpackOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PrimListUnpackOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto meshgridIndexingOp =
|
||||
op.getOperand().getDefiningOp<Torch::AtenMeshgridIndexingOp>();
|
||||
if (!meshgridIndexingOp)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Input is not AtenMeshgridIndexingOp");
|
||||
Location loc = meshgridIndexingOp.getLoc();
|
||||
auto context = meshgridIndexingOp.getContext();
|
||||
auto baseType = NonValueTensorType::getWithLeastStaticInformation(context);
|
||||
SmallVector<Value> tensors;
|
||||
if (!getListConstructElements(meshgridIndexingOp.getTensors(), tensors))
|
||||
return rewriter.notifyMatchFailure(meshgridIndexingOp,
|
||||
"Unable to get tensors");
|
||||
|
||||
int64_t numTensors = tensors.size();
|
||||
bool swapFirstAndSecondTensors = false;
|
||||
|
||||
std::string indexing;
|
||||
if (!matchPattern(meshgridIndexingOp.getIndexing(),
|
||||
m_TorchConstantStr(indexing))) {
|
||||
return rewriter.notifyMatchFailure(meshgridIndexingOp,
|
||||
"Unable to get indexing");
|
||||
}
|
||||
|
||||
if (indexing == "xy" && numTensors >= 2) {
|
||||
swapFirstAndSecondTensors = true;
|
||||
std::swap(tensors[0], tensors[1]);
|
||||
}
|
||||
|
||||
SmallVector<Value> expandShapeValues;
|
||||
for (int64_t i = 0; i < numTensors; i++) {
|
||||
expandShapeValues.push_back(
|
||||
rewriter.create<AtenNumelOp>(loc, tensors[i]));
|
||||
}
|
||||
Value expandShapeList = rewriter.create<PrimListConstructOp>(
|
||||
loc, ListType::get(IntType::get(context)), expandShapeValues);
|
||||
|
||||
SmallVector<Value> meshgrids;
|
||||
Value constFalse =
|
||||
rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(false));
|
||||
|
||||
for (auto [idx, tensor] : llvm::enumerate(tensors)) {
|
||||
Value constantOne =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
SmallVector<Value> tensorViewShapeValues(numTensors, constantOne);
|
||||
tensorViewShapeValues[idx] = expandShapeValues[idx];
|
||||
|
||||
Value viewShapeList = rewriter.create<PrimListConstructOp>(
|
||||
loc, ListType::get(IntType::get(context)), tensorViewShapeValues);
|
||||
Value view =
|
||||
rewriter.create<AtenViewOp>(loc, baseType, tensor, viewShapeList);
|
||||
|
||||
Value expandView = rewriter.create<AtenExpandOp>(
|
||||
loc, baseType, view, expandShapeList, constFalse);
|
||||
meshgrids.push_back(expandView);
|
||||
}
|
||||
|
||||
if (swapFirstAndSecondTensors) {
|
||||
std::swap(meshgrids[0], meshgrids[1]);
|
||||
}
|
||||
rewriter.replaceOp(op, meshgrids);
|
||||
// erase meshgridIndexingOp if no user left
|
||||
if (meshgridIndexingOp.getResult().use_empty())
|
||||
rewriter.eraseOp(meshgridIndexingOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class RecomposeComplexOpsPass
|
||||
: public RecomposeComplexOpsBase<RecomposeComplexOpsPass> {
|
||||
|
@ -742,6 +817,7 @@ public:
|
|||
patterns.add<RecomposeUnbindListUnpack>(context);
|
||||
patterns.add<RecomposeUnbindGetItem>(context);
|
||||
patterns.add<RecomposeChunkListUnpack>(context);
|
||||
patterns.add<RecomposeMeshgridIndexingListUnpack>(context);
|
||||
|
||||
GreedyRewriteConfig config;
|
||||
config.useTopDownTraversal = true;
|
||||
|
|
|
@ -821,6 +821,9 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
|
|||
}
|
||||
|
||||
STABLEHLO_PASS_SET = {
|
||||
"MeshgridIndexingIJ_basic",
|
||||
"MeshgridIndexingXY_basic",
|
||||
"Meshgrid_basic",
|
||||
"SplitWithSizes_Module_basic",
|
||||
"TensorSplitSections_GetItemModule_basic",
|
||||
"TensorSplitSections_ListUnpackModule_basic",
|
||||
|
@ -1477,6 +1480,9 @@ STABLEHLO_CRASHING_SET = set()
|
|||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
"MeshgridIndexingIJ_basic",
|
||||
"MeshgridIndexingXY_basic",
|
||||
"Meshgrid_basic",
|
||||
"AvgPool2dCountIncludePadFalseStaticModule_basic",
|
||||
"TensorSplitSections_GetItemModule_basic",
|
||||
"TensorSplitSections_ListUnpackModule_basic",
|
||||
|
|
|
@ -979,6 +979,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::tensor_split.sections : (Tensor, int, int) -> (Tensor[])")
|
||||
emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
|
||||
emit("aten::chunk : (Tensor, int, int) -> (Tensor[])")
|
||||
emit("aten::meshgrid : (Tensor[]) -> (Tensor[])", has_canonicalizer=True)
|
||||
emit("aten::meshgrid.indexing : (Tensor[], str) -> (Tensor[])")
|
||||
|
||||
# Str ops.
|
||||
emit("aten::add.str : (str, str) -> (str)")
|
||||
|
|
|
@ -57,3 +57,4 @@ def register_all_tests():
|
|||
from . import padding
|
||||
from . import diagonal
|
||||
from . import gridsampler
|
||||
from . import meshgrid
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.framework import TestUtils
|
||||
from torch_mlir_e2e_test.registry import register_test_case
|
||||
from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MeshgridIndexingIJ(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3], torch.int64, True),
|
||||
([4], torch.int64, True),
|
||||
([5], torch.int64, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x, y, z):
|
||||
x1, y1, z1 = torch.meshgrid(x, y, z, indexing="ij")
|
||||
return x1, y1, z1
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeshgridIndexingIJ())
|
||||
def MeshgridIndexingIJ_basic(module, tu: TestUtils):
|
||||
x = torch.tensor([1, 2, 3])
|
||||
y = torch.tensor([4, 5, 6, 7])
|
||||
z = torch.tensor([8, 9, 10, 11, 12])
|
||||
module.forward(x, y, z)
|
||||
|
||||
|
||||
class MeshgridIndexingXY(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3], torch.int64, True),
|
||||
([4], torch.int64, True),
|
||||
([5], torch.int64, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x, y, z):
|
||||
x1, y1, z1 = torch.meshgrid(x, y, z, indexing="xy")
|
||||
return x1, y1, z1
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeshgridIndexingXY())
|
||||
def MeshgridIndexingXY_basic(module, tu: TestUtils):
|
||||
x = torch.tensor([1, 2, 3])
|
||||
y = torch.tensor([4, 5, 6, 7])
|
||||
z = torch.tensor([8, 9, 10, 11, 12])
|
||||
module.forward(x, y, z)
|
||||
|
||||
|
||||
class Meshgrid(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3], torch.int64, True),
|
||||
([4], torch.int64, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
x1, y1 = torch.meshgrid(x, y)
|
||||
return x1, y1
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Meshgrid())
|
||||
def Meshgrid_basic(module, tu: TestUtils):
|
||||
x = torch.tensor([1, 2, 3])
|
||||
y = torch.tensor([4, 5, 6, 7])
|
||||
module.forward(x, y)
|
Loading…
Reference in New Issue