[MHLO] Add torch-to-mhlo e2e support for aten.gather op (#1410)

* Add torch-to-mhlo e2e support for aten.gather op 

* Add more e2e tests for torch.aten.gather op
武家伟 2022-09-25 22:07:46 +08:00 committed by GitHub
parent bc11e1aba6
commit ab7aa01b1e
3 changed files with 137 additions and 0 deletions
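
For context, aten.gather selects one input element per entry of `index` along a single dimension; for dim=1, out[i][j] = input[i][index[i][j]], and `index` must have the same rank as `input`. A minimal PyTorch illustration (not part of the patch):

import torch

inp = torch.tensor([[1., 2.], [3., 4.]])
idx = torch.tensor([[0, 0], [1, 0]])
# Along dim=1: out[i][j] = inp[i][idx[i][j]].
print(torch.gather(inp, 1, idx))  # tensor([[1., 1.], [4., 3.]])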

@@ -24,6 +24,10 @@ EAGER_MODE_XFAIL_SET = {
MHLO_PASS_SET = {
    "BroadcastToIdentityCaseStaticModule_basic",
    "GatherStaticModule_basic",
    "GatherModule_basic",
    "Gather2DInputModdule_basic",
    "GatherRandomIndexModule_basic",
    "ArangeDtypeFloatModule_basic",
    "ArangeDtypeIntModule_basic",
    "ArangeFalsePinMemoryModule_basic",

@@ -10,6 +10,7 @@
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
@@ -166,6 +167,95 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
  return success();
}

// AtenGatherOp
template <>
LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
    AtenGatherOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  Value input = adaptor.self();
  Value index = adaptor.index();
  auto inputType = input.getType().cast<RankedTensorType>();
  auto indexType = index.getType().cast<RankedTensorType>();
  auto indexElemType = indexType.getElementType();
  if (indexType.getRank() != inputType.getRank()) {
    return op.emitError(
        "`index` and `input` params should have the same rank");
  }
  int64_t dim;
  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
    return rewriter.notifyMatchFailure(
        op, "only constant int `dim` param supported");
  }
  dim = toPositiveDim(dim, inputType.getRank());
  if (!isValidDim(dim, inputType.getRank())) {
    return rewriter.notifyMatchFailure(op, "invalid `dim` param detected");
  }
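  // `sparse_grad` only affects how PyTorch materializes the gradient (as a
  // sparse tensor); the forward computation is identical either way, so the
  // lowering only requires it to be a compile-time constant.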
  bool sparseGrad = false;
  if (!matchPattern(op.sparse_grad(), m_TorchConstantBool(&sparseGrad))) {
    return rewriter.notifyMatchFailure(
        op, "only constant boolean `sparse_grad` param supported");
  }

  auto options = getOptions();
  auto indexShapeInfo =
      mhlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
  if (failed(indexShapeInfo)) {
    return rewriter.notifyMatchFailure(
        op, "failed to get dim sizes of `index` param");
  }
  auto intType = rewriter.getIntegerType(options.dimSizeIndexBits);
  auto one = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getIntegerAttr(intType, 1));
  auto toConcatIndexShapeValueVec = *indexShapeInfo;
  toConcatIndexShapeValueVec.push_back(one);
  auto toConcatIndexShape =
      rewriter.create<tensor::FromElementsOp>(loc, toConcatIndexShapeValueVec);

  auto indexShape = indexType.getShape();
  SmallVector<int64_t> toConcatIndexShapeVec(indexShape.begin(),
                                             indexShape.end());
  toConcatIndexShapeVec.push_back(1);
  RankedTensorType toConcatIndexType =
      RankedTensorType::get(toConcatIndexShapeVec, indexElemType);
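
  // mhlo::GatherOp consumes one full input coordinate per output element.
  // Build a [*indexShape, rank] index tensor by concatenating, along a new
  // trailing dimension, the user `index` (reshaped to [*indexShape, 1]) for
  // `dim` and an iota for every other dimension.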
  SmallVector<Value> toConcat;
  for (int64_t i = 0; i < inputType.getRank(); ++i) {
    if (i == dim) {
      toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
          loc, toConcatIndexType, index, toConcatIndexShape));
    } else {
      toConcat.push_back(rewriter.create<mhlo::DynamicIotaOp>(
          loc, toConcatIndexType, toConcatIndexShape,
          rewriter.getI64IntegerAttr(i)));
    }
  }
  auto gatherIndices = rewriter.create<mhlo::ConcatenateOp>(
      loc, toConcat, static_cast<uint64_t>(inputType.getRank()));

  SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);
  int64_t indexVecDim = inputType.getRank();
  SmallVector<int64_t> collapsedDims;
  SmallVector<int64_t> startIndexMap;
  for (int64_t i = 0; i < inputType.getRank(); ++i) {
    collapsedDims.push_back(i);
    startIndexMap.push_back(i);
  }
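  // Every input dimension is a start-index dimension and each gathered slice
  // is a single element (sliceSizes all ones) whose dimensions are collapsed
  // away, so the gather result has exactly the shape of `index`.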
  auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
      rewriter.getContext(),
      /*offsetDims=*/{},
      /*collapsedSliceDims=*/collapsedDims,
      /*startIndexMap=*/startIndexMap,
      /*indexVecDim=*/indexVecDim);
  rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
      op, input, gatherIndices, dimsAttr,
      rewriter.getI64TensorAttr(sliceSizes));
  return success();
}

void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, const TorchToMhloOptions &options) {
@@ -176,5 +266,6 @@ void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
  INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
  INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
  INSERT_ATENOP_PATTERN(AtenGatherOp);
#undef INSERT_ATENOP_PATTERN
}
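
As a sanity check on the index construction above, here is a small NumPy sketch of the same strategy (illustrative only; `gather_like_lowering` is a hypothetical helper written for this note, not part of the patch). It builds one full input coordinate per output element and matches torch.gather semantics via np.take_along_axis:

import numpy as np

def gather_like_lowering(inp, dim, index):
    # One full input coordinate per output element: the user index supplies
    # the `dim` component, iotas supply every other component.
    coords = []
    for i in range(inp.ndim):
        if i == dim:
            coords.append(index)
        else:
            shape = [1] * index.ndim
            shape[i] = index.shape[i]
            coords.append(np.broadcast_to(
                np.arange(index.shape[i]).reshape(shape), index.shape))
    return inp[tuple(coords)]

inp = np.arange(12, dtype=np.float32).reshape(3, 4)
idx = np.array([[0, 3], [2, 1], [1, 0]])
assert np.array_equal(gather_like_lowering(inp, 1, idx),
                      np.take_along_axis(inp, idx, axis=1))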

@@ -643,6 +643,48 @@ def GatherModule_basic(module, tu: TestUtils):

# ==============================================================================


class GatherRandomIndexModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.int64, True),
    ])
    def forward(self, tensor, indices):
        return torch.gather(tensor, 1, indices)


@register_test_case(module_factory=lambda: GatherRandomIndexModule())
def GatherRandomIndexModule_basic(module, tu: TestUtils):
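    # high=3 keeps every random index below the size of the gathered dim (3).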
    module.forward(tu.rand(2, 3, 4), tu.randint(2, 3, 4, high=3))


# ==============================================================================


class Gather2DInputModdule(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, tensor, indices):
        return torch.gather(tensor, 1, indices)


@register_test_case(module_factory=lambda: Gather2DInputModdule())
def Gather2DInputModdule_basic(module, tu: TestUtils):
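    # Indices along dim 1 must stay below the input's dim-1 size (here 5).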
    module.forward(tu.rand(4, 5), torch.tensor([[1, 2, 3], [4, 3, 2]]))


# ==============================================================================


class GatherStaticModule(torch.nn.Module):

    def __init__(self):