mirror of https://github.com/llvm/torch-mlir
[MHLO] Add torch-to-mhlo e2e support for aten.gather op (#1410)
* Add torch-to-mhlo e2e support for aten.gather op
* Add more e2e tests for torch.aten.gather op

pull/1415/head snapshot-20220926.608
parent bc11e1aba6
commit ab7aa01b1e
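For context, `torch.gather` selects one element of the input per entry of an `index` tensor of the same rank, along a single dimension. A minimal PyTorch sketch of the semantics this lowering must reproduce (the values here are illustrative only):

```python
import torch

# For dim=1 on a 2-D tensor: out[i][j] = tensor[i][indices[i][j]].
tensor = torch.tensor([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]])
indices = torch.tensor([[2, 0],
                        [1, 1]])
out = torch.gather(tensor, 1, indices)
print(out)  # tensor([[3., 1.],
            #         [5., 5.]])
```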
@@ -24,6 +24,10 @@ EAGER_MODE_XFAIL_SET = {
MHLO_PASS_SET = {
    "BroadcastToIdentityCaseStaticModule_basic",
    "GatherStaticModule_basic",
    "GatherModule_basic",
    "Gather2DInputModdule_basic",
    "GatherRandomIndexModule_basic",
    "ArangeDtypeFloatModule_basic",
    "ArangeDtypeIntModule_basic",
    "ArangeFalsePinMemoryModule_basic",
@@ -10,6 +10,7 @@
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"

#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
@@ -166,6 +167,95 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
  return success();
}

// AtenGatherOp
template <>
LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
    AtenGatherOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  Value input = adaptor.self();
  Value index = adaptor.index();
  auto inputType = input.getType().cast<RankedTensorType>();
  auto indexType = index.getType().cast<RankedTensorType>();
  auto indexElemType = indexType.getElementType();

  // aten.gather requires `index` to have the same rank as `input`.
  if (indexType.getRank() != inputType.getRank()) {
    return op.emitError("`index` and `input` param should have the same rank");
  }
  int64_t dim;
  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
    return rewriter.notifyMatchFailure(
        op, "only constant int `dim` param supported");
  }
  dim = toPositiveDim(dim, inputType.getRank());
  if (!isValidDim(dim, inputType.getRank())) {
    return rewriter.notifyMatchFailure(op, "invalid `dim` param detected");
  }

  bool sparseGrad = false;
  if (!matchPattern(op.sparse_grad(), m_TorchConstantBool(&sparseGrad))) {
    return rewriter.notifyMatchFailure(
        op, "only constant boolean `sparse_grad` param supported");
  }

  // Materialize the runtime shape index.shape + [1]; each coordinate
  // component built below has this shape and is concatenated on the
  // trailing dimension.
  auto options = getOptions();
  auto indexShapeInfo =
      mhlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
  if (failed(indexShapeInfo)) {
    return rewriter.notifyMatchFailure(
        op, "failed to get dim sizes of `index` param");
  }
  auto intType = rewriter.getIntegerType(options.dimSizeIndexBits);
  auto one = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getIntegerAttr(intType, 1));
  auto toConcatIndexShapeValueVec = *indexShapeInfo;
  toConcatIndexShapeValueVec.push_back(one);
  auto toConcatIndexShape =
      rewriter.create<tensor::FromElementsOp>(loc, toConcatIndexShapeValueVec);

  auto indexShape = indexType.getShape();
  SmallVector<int64_t> toConcatIndexShapeVec(indexShape.begin(),
                                             indexShape.end());
  toConcatIndexShapeVec.push_back(1);
  RankedTensorType toConcatIndexType =
      RankedTensorType::get(toConcatIndexShapeVec, indexElemType);

  // Expand `index` into full coordinates: the user-supplied index provides
  // the coordinate along `dim`, and an iota provides it along every other
  // dimension.
  SmallVector<Value> toConcat;
  for (int64_t i = 0; i < inputType.getRank(); ++i) {
    if (i == dim) {
      toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
          loc, toConcatIndexType, index, toConcatIndexShape));
    } else {
      toConcat.push_back(rewriter.create<mhlo::DynamicIotaOp>(
          loc, toConcatIndexType, toConcatIndexShape,
          rewriter.getI64IntegerAttr(i)));
    }
  }
  auto gatherIndicies = rewriter.create<mhlo::ConcatenateOp>(
      loc, toConcat, static_cast<uint64_t>(inputType.getRank()));
  SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);

  int64_t indexVecDim = inputType.getRank();
  SmallVector<int64_t> collapsedDims;
  SmallVector<int64_t> startIndexMap;
  for (int64_t i = 0; i < inputType.getRank(); ++i) {
    collapsedDims.push_back(i);
    startIndexMap.push_back(i);
  }

  // Every gathered slice is a single element (all slice sizes are 1 and all
  // sliced dimensions are collapsed), so the result shape equals index.shape.
  auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
      rewriter.getContext(),
      /*offsetDims=*/{},
      /*collapsedSliceDims=*/collapsedDims,
      /*startIndexMap=*/startIndexMap,
      /*indexVecDim=*/indexVecDim);

  rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
      op, input, gatherIndicies, dimsAttr,
      rewriter.getI64TensorAttr(sliceSizes));
  return success();
}

void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, const TorchToMhloOptions &options) {
@@ -176,5 +266,6 @@ void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
  INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
  INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
  INSERT_ATENOP_PATTERN(AtenGatherOp);
#undef INSERT_ATENOP_PATTERN
}
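The conversion reduces `aten.gather` to a single `mhlo.gather` by expanding `index` into full coordinates: along `dim` the index is reshaped to `index.shape + [1]`, every other axis contributes a `DynamicIotaOp` of that same shape, and concatenating on the trailing dimension yields one rank-length coordinate per output element. With all slice sizes set to 1 and every sliced dimension collapsed, `mhlo.gather` then reads exactly one input element per coordinate. A rough NumPy model of that strategy (a sketch of the idea, not the generated IR):

```python
import numpy as np

def gather_via_full_coords(inp, dim, index):
    """Mimic the lowering: build one coordinate plane per input dimension,
    then fetch a single element per coordinate tuple (slice sizes all 1)."""
    assert inp.ndim == index.ndim  # same-rank requirement, as in the pattern
    planes = []
    for axis in range(inp.ndim):
        if axis == dim:
            # The user-supplied index is the coordinate along `dim`
            # (the DynamicReshapeOp branch).
            planes.append(index[..., None])
        else:
            # An iota supplies the coordinate along every other axis
            # (the DynamicIotaOp branch).
            shape = [1] * index.ndim
            shape[axis] = index.shape[axis]
            iota = np.arange(index.shape[axis]).reshape(shape)
            planes.append(np.broadcast_to(iota, index.shape)[..., None])
    # ConcatenateOp: coords has shape index.shape + [rank].
    coords = np.concatenate(planes, axis=-1)
    out = np.empty(index.shape, dtype=inp.dtype)
    for pos in np.ndindex(*index.shape):
        out[pos] = inp[tuple(coords[pos])]  # one element per coordinate
    return out

inp = np.arange(12, dtype=np.float32).reshape(3, 4)
idx = np.array([[3, 0], [1, 2], [0, 0]])
assert np.array_equal(gather_via_full_coords(inp, 1, idx),
                      np.take_along_axis(inp, idx, axis=1))
```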
@@ -643,6 +643,48 @@ def GatherModule_basic(module, tu: TestUtils):

# ==============================================================================


class GatherRandomIndexModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.int64, True),
    ])
    def forward(self, tensor, indices):
        return torch.gather(tensor, 1, indices)


@register_test_case(module_factory=lambda: GatherRandomIndexModule())
def GatherRandomIndexModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4), tu.randint(2, 3, 4, high=3))


# ==============================================================================


class Gather2DInputModdule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, tensor, indices):
        return torch.gather(tensor, 1, indices)


@register_test_case(module_factory=lambda: Gather2DInputModdule())
def Gather2DInputModdule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 5), torch.tensor([[1, 2, 3], [4, 3, 2]]))


# ==============================================================================


class GatherStaticModule(torch.nn.Module):

    def __init__(self):