mirror of https://github.com/llvm/torch-mlir
[MHLO] Support IndexPut.
parent c622f59300
commit bba254ab1c
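For context, the case handled by this commit corresponds to torch's index_put_ with a single 2D index tensor, 3D values, and accumulate=True on a 2D input. A minimal reference sketch of those semantics in plain C++ (illustration only; the function name and container types are hypothetical, and this code is not part of the commit):

#include <cstdint>
#include <vector>

// input:  [M, N], index: [Q, P], values: [Q, P, N], accumulate = true.
// Each index[q][p] picks a row of `input`; the matching values[q][p] row is
// added into it. The lowering below expresses the same computation as a
// single mhlo.scatter whose update region performs the addition.
void indexPutAccumulateRef(
    std::vector<std::vector<float>> &input,
    const std::vector<std::vector<int64_t>> &index,
    const std::vector<std::vector<std::vector<float>>> &values) {
  for (size_t q = 0; q < index.size(); ++q) {
    for (size_t p = 0; p < index[q].size(); ++p) {
      int64_t row = index[q][p];
      for (size_t n = 0; n < input[row].size(); ++n)
        input[row][n] += values[q][p][n];
    }
  }
}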
@@ -1292,6 +1292,115 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
  return success();
}

// Aten_IndexPutImplOp
template <>
LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
    Aten_IndexPutImplOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  Value input = adaptor.getSelf();
  Value values = adaptor.getValues();
  auto inputType = input.getType().cast<RankedTensorType>();
  SmallVector<int64_t> inputShape(inputType.getShape());
  if (inputShape.size() != 2) {
    return op->emitError("only support 2D input in index_put");
  }
  auto valuesType = values.getType().cast<RankedTensorType>();
  SmallVector<int64_t> valuesShape(valuesType.getShape());
  if (valuesShape.size() != 3) {
    return op->emitError("only support 3D values in index_put");
  }

  bool accumulate;
  if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) {
    return rewriter.notifyMatchFailure(
        op, "unimplemented: accumulate must be a constant");
  }
  if (!accumulate) {
    return op->emitError("accumulate must be true");
  }
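  // Only accumulate = true is supported: the scatter update computation built
  // below adds the existing element and the incoming update, which is exactly
  // index_put_ with accumulate=True.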

  SmallVector<Value> indicesList;
  getListConstructElements(adaptor.getIndices(), indicesList);

  // TODO: Add support for cases with indices list size not equal to 1.
  if (indicesList.size() != 1) {
    return rewriter.notifyMatchFailure(
        op, "Unimplemented: Indices list size != 1");
  }
  Value index = indicesList[0];

  if (index.getType().isa<Torch::NoneType>())
    return rewriter.notifyMatchFailure(op, "Index tensor must not be None.");

  index = typeConverter->materializeTargetConversion(
      rewriter, loc, typeConverter->convertType(index.getType()), index);

  // input:  [M, N]
  // index:  [Q, P]    -> [Q * P, 1]
  // values: [Q, P, N] -> [Q * P, N]
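  // For example (sizes chosen purely for illustration): with M = 4, N = 3,
  // Q = 2, P = 5, index reshapes from [2, 5] to [10, 1] and values from
  // [2, 5, 3] to [10, 3].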

  auto indexType = index.getType().cast<RankedTensorType>();
  SmallVector<int64_t> indexShape(indexType.getShape());
  if (indexShape.size() != 2) {
    return op->emitError("only support 2D index in index_put");
  }
  auto reshapedIndexType = RankedTensorType::get(
      {indexShape[0] * indexShape[1], 1}, indexType.getElementType());
  Value reshapedIndex =
      rewriter.create<mhlo::ReshapeOp>(loc, reshapedIndexType, index);

  auto reshapedValuesType = RankedTensorType::get(
      {valuesShape[0] * valuesShape[1], valuesShape[2]},
      valuesType.getElementType());
  Value reshapedValues =
      rewriter.create<mhlo::ReshapeOp>(loc, reshapedValuesType, values);

  // setup ScatterDimensionNumbersAttr
  SmallVector<int64_t> updateWindowDims{1};
  SmallVector<int64_t> insertedWindowDims{0};
  SmallVector<int64_t> scatterDimsToOperandDims{0};
  int64_t indexVectorDim = indexShape.size() - 1;
  auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbersAttr::get(
      rewriter.getContext(),
      /*updateWindowDims=*/updateWindowDims,
      /*insertedWindowDims=*/insertedWindowDims,
      /*scatterDimsToOperandDims=*/scatterDimsToOperandDims,
      /*indexVectorDim=*/indexVectorDim);
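  // With updateWindowDims = [1], insertedWindowDims = [0],
  // scatterDimsToOperandDims = [0], and indexVectorDim = 1, each row of
  // reshapedIndex names one row of the [M, N] operand, and the corresponding
  // [N]-element row of reshapedValues is combined into that row by the update
  // computation.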

  BoolAttr indices_are_sorted = rewriter.getBoolAttr(false);
  BoolAttr unique_indices = rewriter.getBoolAttr(false);

  auto outType = getTypeConverter()->convertType(op.getType());
  auto mhloScatterOp = rewriter.replaceOpWithNewOp<mhlo::ScatterOp>(
      op, outType, input, reshapedIndex, reshapedValues,
      scatter_dimension_numbers, indices_are_sorted, unique_indices);

  Block &block = mhloScatterOp.getUpdateComputation().emplaceBlock();
  // Add block arguments: two scalar tensors holding the existing element and
  // the incoming update.
  auto blockValArgumentType =
      RankedTensorType::get({}, inputType.getElementType());
  block.addArgument(blockValArgumentType, op->getLoc());
  block.addArgument(blockValArgumentType, op->getLoc());
  auto *firstValArg = block.args_begin();
  auto *secondValArg = std::next(firstValArg);
  // Create the block body: add the two arguments, giving accumulate semantics.
  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);

    Value res = rewriter.create<mhlo::AddOp>(op->getLoc(), *firstValArg,
                                             *secondValArg);
    rewriter.create<mhlo::ReturnOp>(op->getLoc(), res);
  }

  return success();
}

template <>
LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
    AtenGeluBackwardOp op, OpAdaptor adaptor,

@@ -1473,6 +1582,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
  INSERT_ATENOP_PATTERN(AtenCatOp);
  INSERT_ATENOP_PATTERN(AtenClampOp);
  INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
  INSERT_ATENOP_PATTERN(Aten_IndexPutImplOp);

  INSERT_ATENOP_PATTERN(AtenBatchNormOp);
  INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);