[MHLO] Support IndexPut.

ziheng/mhlo-index_put
Ziheng Jiang 2023-02-01 19:30:30 -08:00
parent c622f59300
commit bba254ab1c
1 changed files with 110 additions and 0 deletions

View File

@ -1292,6 +1292,115 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
return success();
}
// Aten_IndexPutImplOp
template <>
LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
Aten_IndexPutImplOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
Value input = adaptor.getSelf();
Value values = adaptor.getValues();
auto inputType = input.getType().cast<RankedTensorType>();
SmallVector<int64_t> inputShape(inputType.getShape());
if (inputShape.size() != 2) {
return op->emitError("only support 2D input in index_put");
}
auto valuesType = values.getType().cast<RankedTensorType>();
SmallVector<int64_t> valuesShape(valuesType.getShape());
if (valueShape.size() != 3) {
return op->emitError("only support 3D values in index_put");
}
bool accumulate;
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: accumulate must be a constant");
}
if (!accumulate) {
return op->emitError("accumulate must be true");
}
SmallVector<Value> indicesList;
getListConstructElements(adaptor.getIndices(), indicesList);
// TODO: Add support for cases with indices list size not equal to 1.
if (indicesList.size() != 1) {
return rewriter.notifyMatchFailure(
op, "Unimplemented: Indices list size != 1");
}
Value index = indicesList[0];
if (index.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(op, "Index tensor must not be None.");
index = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(index.getType()), index);
// input: [M, N]
// index: [Q, P] -> [Q * P, 1]
// values: [Q, P, N] -> [Q * P, N]
auto indexType = index.getType().cast<RankedTensorType>();
SmallVector<int64_t> indexShape(indexType.getShape());
if (indexShape.size() != 2) {
return op->emitError("only support 2D index in index_put");
}
auto reshapedIndexType = RankedTensorType::get(
{indexShape[0] * indexShape[1], 1},
indexType.getElementType());
Value reshapedIndex = rewriter.create<mhlo::ReshapeOp>(
loc, reshapedIndexType, index);
auto reshapedValuesType = RankedTensorType::get(
{valuesShape[0] * valuesShape[1], valuesShape[2]},
valuesType.getElementType());
Value reshapedValues = rewriter.create<mhlo::ReshapeOp>(
loc, reshapedValuesType, values);
// setup ScatterDimensionNumbersAttr
SmallVector<int64_t> updateWindowDims{1};
SmallVector<int64_t> insertedWindowDims{0};
SmallVector<int64_t> scatterDimsToOperandDims{0};
int64_t indexVectorDim = indexShape.size() - 1;
auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbersAttr::get(
rewriter.getContext(),
/*updateWindowDims=*/updateWindowDims,
/*insertedWindowDims=*/insertedWindowDims,
/*scatterDimsToOperandDims=*/scatterDimsToOperandDims,
/*indexVectorDim=*/indexVectorDim
);
BoolAttr indices_are_sorted = rewriter.getBoolAttr(false);
BoolAttr unique_indices = rewriter.getBoolAttr(false);
auto outType = getTypeConverter()->convertType(op.getType());
auto mhloScatterOp = rewriter.replaceOpWithNewOp<mhlo::ScatterOp>(op, outType,
input, reshapedIndex, reshapedValues,
scatter_dimension_numbers, indices_are_sorted, unique_indices);
Block &block = mhloScatterOp.getUpdateComputation().emplaceBlock();
// Add block arguments
auto blockValArgumentType =
RankedTensorType::get({}, inputType.getElementType());
block.addArgument(blockValArgumentType, op->getLoc());
block.addArgument(blockValArgumentType, op->getLoc());
auto *firstValArg = block.args_begin();
auto *secondValArg = std::next(firstValArg);
// create block body
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value res = rewriter.create<mhlo::AddOp>(
op->getLoc(), *firstValArg, *secondValArg);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), res);
}
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
AtenGeluBackwardOp op, OpAdaptor adaptor,
@ -1473,6 +1582,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenCatOp);
INSERT_ATENOP_PATTERN(AtenClampOp);
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
INSERT_ATENOP_PATTERN(Aten_IndexPutImplOp);
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);