2022-07-25 23:47:46 +08:00
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
//
|
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
|
//
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
|
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
2022-07-25 23:47:46 +08:00
|
|
|
|
|
|
|
|
|
#include "../PassDetail.h"
|
2023-02-02 21:29:47 +08:00
|
|
|
|
#include "PopulatePatterns.h"
|
|
|
|
|
|
2022-10-05 21:28:06 +08:00
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
2022-07-25 23:47:46 +08:00
|
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
2023-02-02 21:29:47 +08:00
|
|
|
|
#include "stablehlo/dialect/StablehloOps.h"
|
2023-05-25 02:13:57 +08:00
|
|
|
|
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
|
2022-07-25 23:47:46 +08:00
|
|
|
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
2023-05-25 02:13:57 +08:00
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
2022-07-25 23:47:46 +08:00
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
|
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
using namespace mlir::torch;
|
|
|
|
|
using namespace mlir::torch::Torch;
|
2023-02-02 21:29:47 +08:00
|
|
|
|
using namespace mlir::torch::torch_to_stablehlo;
|
2022-07-25 23:47:46 +08:00
|
|
|
|
|
|
|
|
|
namespace {
|
2023-09-05 21:28:37 +08:00
|
|
|
|
// Materializes a zero-filled stablehlo.constant of type `constType`, used as
// the reduction init value / placeholder result when lowering gather-scatter
// style ops.
//
// Only AtenEmbeddingBagPaddingIdxOp is handled:
//   * floating-point element types produce +0.0,
//   * integer element types produce 0, except 8-bit integers (not handled).
// Returns nullptr WITHOUT a diagnostic when `constType` is not statically
// shaped; returns nullptr after emitting an error on `op` for every other
// unsupported op / element-type combination.
static Value createInitialValueForGatherScatterOp(Operation *op,
                                                  RankedTensorType constType,
                                                  PatternRewriter &rewriter) {
  if (!constType.hasStaticShape())
    return nullptr;

  Type elemTy = constType.getElementType();
  // Stays null unless one of the supported cases below fills it in.
  DenseElementsAttr zeroAttr;
  if (isa<AtenEmbeddingBagPaddingIdxOp>(op)) {
    if (auto floatTy = dyn_cast<mlir::FloatType>(elemTy)) {
      zeroAttr = DenseElementsAttr::get(
          constType, {APFloat::getZero(floatTy.getFloatSemantics(),
                                       /*negative=*/false)});
    } else if (isa<mlir::IntegerType>(elemTy)) {
      unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
      if (bitWidth != 8)
        zeroAttr = DenseElementsAttr::get(constType, {APInt::getZero(bitWidth)});
    }
  }

  if (!zeroAttr) {
    op->emitError("unimplemented lowering in "
                  "createInitialValueForGatherScatterOp");
    return nullptr;
  }
  return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                zeroAttr);
}
|
|
|
|
|
|
2022-07-25 23:47:46 +08:00
|
|
|
|
// Gathers slices of `input` along the single dimension `axis`, selecting
// them with the integer tensor `indices`, and emits a
// stablehlo.dynamic_gather.
//
// The result type is a ranked tensor with the element type of `input` and
// shape
//   input.shape[:axis] + indices.shape + input.shape[axis + 1:]
// Slice sizes are computed at runtime (tensor.dim + index_cast to a
// `dimSizeIndexBits`-wide integer), so dynamic input dimensions are allowed.
//
// Both `input` and `indices` must be ranked tensors; `cast` asserts otherwise
// (the previous unchecked dyn_cast would have dereferenced a null type).
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
                                  Value input, Value indices, int64_t axis,
                                  size_t dimSizeIndexBits) {
  auto loc = op->getLoc();
  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
  Value one = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getIntegerAttr(intType, 1));

  auto inputRankTy = cast<RankedTensorType>(input.getType());
  auto indicesRankTy = cast<RankedTensorType>(indices.getType());
  auto inputRank = inputRankTy.getRank();
  auto indicesRank = indicesRankTy.getRank();

  // sliceSizes: the full extent of every dimension except `axis`, where each
  // index selects exactly one element.
  SmallVector<Value, 4> sliceSizes;
  sliceSizes.reserve(inputRank);
  for (int64_t r = 0; r < inputRank; ++r) {
    if (r == axis) {
      sliceSizes.push_back(one);
    } else {
      sliceSizes.push_back(rewriter.create<arith::IndexCastOp>(
          loc, intType, rewriter.create<tensor::DimOp>(loc, input, r)));
    }
  }
  auto sliceSizesTensor =
      rewriter.create<tensor::FromElementsOp>(loc, sliceSizes);

  // offsetDims: input dims before `axis` keep their position; dims after it
  // are shifted right by the (indicesRank - 1) dims the indices contribute.
  SmallVector<int64_t, 4> offsetDims;
  offsetDims.reserve(inputRank);
  for (int64_t r = 0; r < axis; ++r) {
    offsetDims.push_back(r);
  }
  for (int64_t r = axis + 1; r < inputRank; ++r) {
    offsetDims.push_back(r + indicesRank - 1);
  }

  // collapsedSliceDims: `axis` has slice size 1 and is dropped.
  SmallVector<int64_t, 4> collapsedSliceDims(1, axis);
  // startIndexMap: each scalar index addresses dimension `axis`.
  SmallVector<int64_t, 4> startIndexMap(1, axis);
  // indexVecDim == indicesRank: every element of `indices` is a full
  // (1-element) start index.
  int64_t indexVecDim = indicesRank;
  auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
      rewriter.getContext(),
      /*offsetDims=*/offsetDims,
      /*collapsedSliceDims=*/collapsedSliceDims,
      /*operandBatchingDims=*/{},
      /*startIndicesBatchingDims=*/{},
      /*startIndexMap=*/startIndexMap,
      /*indexVecDim=*/indexVecDim);

  // outputShape = input.shape[:axis] + indices.shape +
  // input.shape[axis + 1:]
  auto inputShape = inputRankTy.getShape();
  auto indicesShape = indicesRankTy.getShape();
  SmallVector<int64_t, 4> outputShape(inputShape.begin(),
                                      inputShape.begin() + axis);
  outputShape.insert(outputShape.end(), indicesShape.begin(),
                     indicesShape.end());
  outputShape.insert(outputShape.end(), inputShape.begin() + axis + 1,
                     inputShape.end());

  // create output tensor type
  auto outputTy =
      RankedTensorType::get(outputShape, inputRankTy.getElementType());
  return rewriter
      .create<stablehlo::DynamicGatherOp>(loc, outputTy, input, indices,
                                          sliceSizesTensor, dimsAttr)
      .getResult();
}
|
2023-03-23 04:41:04 +08:00
|
|
|
|
|
|
|
|
|
template <typename OpTy, typename OpAdaptor>
|
|
|
|
|
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter,
|
|
|
|
|
SmallVector<Value> &resultShape,
|
|
|
|
|
SmallVector<Value> &offsets,
|
|
|
|
|
SmallVector<Value> &strides) {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
auto input = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
|
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
2023-03-23 04:41:04 +08:00
|
|
|
|
|
|
|
|
|
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
|
|
|
|
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
|
|
|
|
|
|
|
|
|
int64_t dim;
|
|
|
|
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
|
|
|
|
return op->emitError("unimplemented: dim is not constant");
|
|
|
|
|
|
|
|
|
|
int64_t inputRank = inputType.getRank();
|
|
|
|
|
dim = toPositiveDim(dim, inputRank);
|
|
|
|
|
if (!isValidDim(dim, inputRank))
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
|
|
|
|
|
Value dimSize = inputShape[dim];
|
|
|
|
|
|
|
|
|
|
Value torchTypeStart = op.getStart();
|
|
|
|
|
Value torchTypeEnd = op.getEnd();
|
|
|
|
|
Value builtinTypeStart = adaptor.getStart();
|
|
|
|
|
Value builtinTypeEnd = adaptor.getEnd();
|
|
|
|
|
|
2024-05-31 14:45:13 +08:00
|
|
|
|
if (isa<OptionalType>(torchTypeStart.getType()) ||
|
|
|
|
|
isa<OptionalType>(torchTypeEnd.getType()))
|
2023-03-23 04:41:04 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
|
|
|
|
|
|
|
|
|
|
int64_t step;
|
|
|
|
|
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
|
2024-04-28 05:00:56 +08:00
|
|
|
|
if (!isa<Torch::NoneType>(op.getStep().getType()))
|
2023-03-23 04:41:04 +08:00
|
|
|
|
return op->emitError("unimplemented: step is not constant");
|
|
|
|
|
step = 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value start = toPositiveValidDim(rewriter, loc, torchTypeStart,
|
|
|
|
|
builtinTypeStart, zero, dimSize);
|
|
|
|
|
Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd,
|
|
|
|
|
dimSize, dimSize);
|
|
|
|
|
|
|
|
|
|
// end >= start ? end : start
|
|
|
|
|
Value endSgeStart = rewriter.create<arith::CmpIOp>(
|
|
|
|
|
loc, arith::CmpIPredicate::sge, end, start);
|
|
|
|
|
end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
|
|
|
|
|
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step);
|
|
|
|
|
|
|
|
|
|
// Slice logic: resultSize = floordiv(end - start + step - 1, step)
|
|
|
|
|
resultShape = getTensorSizes(rewriter, loc, input);
|
|
|
|
|
Value len = rewriter.create<arith::SubIOp>(loc, end, start);
|
|
|
|
|
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
|
|
|
|
|
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
|
|
|
|
|
resultSize = rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
|
|
|
|
|
resultShape[dim] = resultSize;
|
|
|
|
|
|
|
|
|
|
strides.resize(inputType.getRank(), one);
|
|
|
|
|
offsets.resize(inputType.getRank(), zero);
|
|
|
|
|
|
|
|
|
|
offsets[dim] = start;
|
|
|
|
|
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
2022-09-01 10:36:02 +08:00
|
|
|
|
} // namespace
|
2022-07-25 23:47:46 +08:00
|
|
|
|
|
2024-04-01 19:39:49 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// A helper function used to generate stablehlo's ScatterIndices or
|
|
|
|
|
// GatherIndices from torch's indices, usually appear in torch ops, like
|
|
|
|
|
// aten.index.Tensor or aten.input_put A usage example is as follow: Input: [[1,
|
|
|
|
|
// 2, 3],
|
|
|
|
|
// [4, 5, 6],
|
|
|
|
|
// [7, 8, 9]]
|
|
|
|
|
// Indices[0]: [[0, 0, 0],
|
|
|
|
|
// [2, 2, 0]]
|
|
|
|
|
// Indices[1]: [[2],
|
|
|
|
|
// [1]]
|
|
|
|
|
// Step 1: broadcast indices tensors
|
|
|
|
|
// Indices[0]: [[0, 0, 0],
|
|
|
|
|
// [2, 2, 0]]
|
|
|
|
|
// Indices[1]: [[2, 2, 2],
|
|
|
|
|
// [1, 1, 1]]
|
|
|
|
|
// Step 2: concat index tensors at a unsqueezed -1 dimension.
|
|
|
|
|
// Indices: [[[0, 2], [0, 2], [0, 2]],
|
|
|
|
|
// [[2, 1], [2, 1], [0, 1]]]
|
|
|
|
|
FailureOr<Value> broadcastAndConcatIndices(Operation *op,
|
|
|
|
|
ConversionPatternRewriter &rewriter,
|
|
|
|
|
SmallVector<Value> indexTensors,
|
|
|
|
|
llvm::ArrayRef<int64_t> inputShape,
|
|
|
|
|
int &maxIndexRank) {
|
|
|
|
|
// Step 1: broadcast indices tensors
|
|
|
|
|
SmallVector<int64_t> indicesShape;
|
|
|
|
|
SmallVector<int64_t> expandShape;
|
|
|
|
|
SmallVector<int64_t> concatShape;
|
|
|
|
|
// concat index tensor into to indices tensor for concat
|
|
|
|
|
for (size_t i = 0; i < indexTensors.size(); i++) {
|
|
|
|
|
auto indexTensor = indexTensors[i];
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto indexTensorType = cast<RankedTensorType>(indexTensor.getType());
|
2024-04-01 19:39:49 +08:00
|
|
|
|
for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) {
|
|
|
|
|
if (size == kUnknownSize)
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
maxIndexRank = std::max(maxIndexRank, (int)indexTensorType.getRank());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> refinedInputShape = makeShapeTorchCompatible(inputShape);
|
|
|
|
|
for (int64_t size : refinedInputShape) {
|
|
|
|
|
if (size == kUnknownSize) {
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < maxIndexRank; i++) {
|
|
|
|
|
indicesShape.push_back(refinedInputShape[i]);
|
|
|
|
|
expandShape.push_back(refinedInputShape[i]);
|
|
|
|
|
concatShape.push_back(refinedInputShape[i]);
|
|
|
|
|
}
|
|
|
|
|
expandShape.push_back(1);
|
|
|
|
|
concatShape.push_back(indexTensors.size());
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> broadcastedIndices;
|
2024-05-16 15:33:23 +08:00
|
|
|
|
Type indexElemTy = rewriter.getI64Type();
|
2024-04-01 19:39:49 +08:00
|
|
|
|
RankedTensorType bcastIndexType =
|
|
|
|
|
RankedTensorType::get(indicesShape, indexElemTy);
|
|
|
|
|
for (auto indexTensor : indexTensors) {
|
|
|
|
|
Value bcastVal =
|
|
|
|
|
hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType);
|
|
|
|
|
RankedTensorType reshapeType =
|
|
|
|
|
RankedTensorType::get(expandShape, indexElemTy);
|
|
|
|
|
bcastVal = rewriter.create<stablehlo::ReshapeOp>(op->getLoc(), reshapeType,
|
|
|
|
|
bcastVal);
|
|
|
|
|
broadcastedIndices.push_back(bcastVal);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Step 2: concat index tensors at a unsqueezed -1 dimension.
|
|
|
|
|
Value finalIndexTensor = broadcastedIndices[0];
|
|
|
|
|
if (broadcastedIndices.size() > 1) {
|
|
|
|
|
RankedTensorType concatTy = RankedTensorType::get(concatShape, indexElemTy);
|
|
|
|
|
finalIndexTensor = rewriter.create<stablehlo::ConcatenateOp>(
|
|
|
|
|
op->getLoc(), concatTy, ValueRange(broadcastedIndices),
|
|
|
|
|
concatShape.size() - 1);
|
|
|
|
|
}
|
|
|
|
|
return finalIndexTensor;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
|
// Ref:
|
|
|
|
|
// https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
|
2022-07-25 23:47:46 +08:00
|
|
|
|
// padding_idx (int, optional)
|
2023-02-02 21:29:47 +08:00
|
|
|
|
// – If specified, the entries at padding_idx do not contribute to the
|
|
|
|
|
// gradient; therefore, the embedding vector at padding_idx is not updated
|
|
|
|
|
// during training, i.e. it remains as a fixed “pad”.
|
2022-07-25 23:47:46 +08:00
|
|
|
|
// scale_grad_by_freq (boolean, optional)
|
|
|
|
|
// – If given, this will scale gradients by the inverse of frequency of the
|
|
|
|
|
// words in the mini-batch. Default False.
|
|
|
|
|
// sparse (bool, optional)
|
|
|
|
|
// – If True, gradient w.r.t. weight matrix will be a sparse tensor.
|
|
|
|
|
template <>
|
|
|
|
|
LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
|
|
|
|
|
AtenEmbeddingOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
auto weight = adaptor.getWeight();
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto weightTy = cast<RankedTensorType>(weight.getType());
|
2022-07-25 23:47:46 +08:00
|
|
|
|
if (!weightTy)
|
|
|
|
|
return op.emitError("only ranked tensor types are supported");
|
|
|
|
|
|
|
|
|
|
int64_t padding_idx;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getPaddingIdx(), m_TorchConstantInt(&padding_idx)))
|
2022-07-25 23:47:46 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant padding_idx is currently supported");
|
|
|
|
|
|
|
|
|
|
bool scale_grad_by_freq;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getScaleGradByFreq(),
|
2022-07-25 23:47:46 +08:00
|
|
|
|
m_TorchConstantBool(&scale_grad_by_freq)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant scale_grad_by_freq is currently supported");
|
|
|
|
|
if (scale_grad_by_freq)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "scale gradients is currently not supported");
|
|
|
|
|
bool sparse;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getSparse(), m_TorchConstantBool(&sparse)))
|
2022-07-25 23:47:46 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant sparse is currently supported");
|
|
|
|
|
if (sparse)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "sparse gradients is currently not supported");
|
|
|
|
|
|
2022-09-01 10:36:02 +08:00
|
|
|
|
Value output = gatherTensorAlongSingleAxis(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits);
|
2023-02-02 21:29:47 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
|
2022-07-25 23:47:46 +08:00
|
|
|
|
op, getTypeConverter()->convertType(op.getType()), output);
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
2023-09-05 21:28:37 +08:00
|
|
|
|
template <>
|
|
|
|
|
LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
|
|
|
|
|
AtenEmbeddingBagPaddingIdxOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
Value weight = adaptor.getWeight();
|
|
|
|
|
Value indices = adaptor.getIndices();
|
|
|
|
|
Value offsets = adaptor.getOffsets();
|
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto weightTy = cast<RankedTensorType>(weight.getType());
|
2023-09-05 21:28:37 +08:00
|
|
|
|
if (weightTy && weightTy.hasStaticShape() && weightTy.getRank() != 2)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "weight must be rank 2 tensor with static shapes");
|
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto indicesTy = cast<RankedTensorType>(indices.getType());
|
2023-09-05 21:28:37 +08:00
|
|
|
|
if (indicesTy && indicesTy.hasStaticShape() && indicesTy.getRank() != 1)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "indices must be a vector with static shapes");
|
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto offsetsTy = cast<RankedTensorType>(offsets.getType());
|
2023-09-05 21:28:37 +08:00
|
|
|
|
if (offsetsTy && offsetsTy.getRank() != 1 && offsetsTy.hasStaticShape() &&
|
|
|
|
|
offsetsTy.getShape()[0] == 1)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "offsets must be a vector with static shape equal to 1");
|
|
|
|
|
|
2024-05-31 14:45:13 +08:00
|
|
|
|
if (!isa<Torch::NoneType>(op.getPaddingIdx().getType()))
|
2023-09-05 21:28:37 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Unimplemented: padding_idx should be none");
|
|
|
|
|
|
2024-05-31 14:45:13 +08:00
|
|
|
|
if (!isa<Torch::NoneType>(op.getPerSampleWeights().getType()))
|
2023-09-05 21:28:37 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Unimplemented: per_sample_weights should be none");
|
|
|
|
|
|
|
|
|
|
bool includeLastOffset;
|
|
|
|
|
if (!matchPattern(op.getIncludeLastOffset(),
|
|
|
|
|
m_TorchConstantBool(&includeLastOffset))) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "include_last_offset is expected to be a constant boolean value.");
|
|
|
|
|
}
|
|
|
|
|
if (includeLastOffset)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "include_last_offset is currently not supported");
|
|
|
|
|
|
|
|
|
|
bool scaleGradByFreq;
|
|
|
|
|
if (!matchPattern(op.getScaleGradByFreq(),
|
|
|
|
|
m_TorchConstantBool(&scaleGradByFreq)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant scale_grad_by_freq is currently supported");
|
|
|
|
|
if (scaleGradByFreq)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "scale gradients is currently not supported");
|
|
|
|
|
|
|
|
|
|
bool sparse;
|
|
|
|
|
if (!matchPattern(op.getSparse(), m_TorchConstantBool(&sparse)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant sparse is currently supported");
|
|
|
|
|
if (sparse)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "sparse gradients is currently not supported");
|
|
|
|
|
|
|
|
|
|
int64_t modeInt;
|
|
|
|
|
if (!matchPattern(op.getMode(), m_TorchConstantInt(&modeInt))) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "mode is expected to be a constant integer value.");
|
|
|
|
|
}
|
|
|
|
|
if (modeInt != torch_upstream::EmbeddingBagMode::MODE_SUM) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"Unimplemented: Mean and Max mode are "
|
|
|
|
|
"not supported yet for EmbeddingBag.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const auto &options =
|
|
|
|
|
ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::getOptions();
|
|
|
|
|
auto weightDimSizes =
|
|
|
|
|
*hlo::getDimSizesOfTensor(rewriter, op, weight, options.dimSizeIndexBits);
|
|
|
|
|
auto indicesDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, indices,
|
|
|
|
|
options.dimSizeIndexBits);
|
|
|
|
|
auto offsetsDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, offsets,
|
|
|
|
|
options.dimSizeIndexBits);
|
|
|
|
|
|
|
|
|
|
Value gatherOutput = gatherTensorAlongSingleAxis(
|
|
|
|
|
rewriter, op, weight, indices, 0, options.dimSizeIndexBits);
|
|
|
|
|
|
|
|
|
|
Type elementTy = weightTy.getElementType();
|
|
|
|
|
auto constType = RankedTensorType::get({}, elementTy);
|
|
|
|
|
Value initValue =
|
|
|
|
|
createInitialValueForGatherScatterOp(op, constType, rewriter);
|
|
|
|
|
if (!initValue)
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
Bump stablehlo to openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 (#2821)
With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
rewriter.startRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
rewriter.finalizeRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
rewriter.cancelRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```
I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it.
It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test
...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
^
LLVM ERROR: Failed to infer result type(s).
```
Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
2024-02-01 06:21:17 +08:00
|
|
|
|
op.getLoc(), gatherOutput, initValue, rewriter.getDenseI64ArrayAttr({0}),
|
|
|
|
|
elementTy);
|
2023-09-05 21:28:37 +08:00
|
|
|
|
|
|
|
|
|
Region ®ion = stablehloReduceOp.getBody();
|
|
|
|
|
Block &block = region.emplaceBlock();
|
|
|
|
|
auto blockArgumentTy = RankedTensorType::get({}, elementTy);
|
|
|
|
|
|
|
|
|
|
block.addArgument(blockArgumentTy, op->getLoc());
|
|
|
|
|
block.addArgument(blockArgumentTy, op->getLoc());
|
|
|
|
|
|
|
|
|
|
auto *firstArgument = block.args_begin();
|
|
|
|
|
auto secondArgument = block.args_rbegin();
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
|
|
|
rewriter.setInsertionPointToStart(&block);
|
|
|
|
|
Value addResult = rewriter.create<stablehlo::AddOp>(
|
|
|
|
|
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
|
|
|
|
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto outShapeInfo =
|
|
|
|
|
hlo::getDimSizesOfTensor(rewriter, op, weight, options.dimSizeIndexBits);
|
|
|
|
|
if (failed(outShapeInfo)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "failed to get dimension sizes of the input");
|
|
|
|
|
}
|
|
|
|
|
auto outShapeVec = *outShapeInfo;
|
|
|
|
|
auto one = rewriter.create<mlir::arith::ConstantOp>(
|
|
|
|
|
op->getLoc(), rewriter.getIntegerAttr(
|
|
|
|
|
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
|
|
|
|
|
outShapeVec[0] = one;
|
|
|
|
|
auto outShapeTensor =
|
|
|
|
|
rewriter.create<mlir::tensor::FromElementsOp>(op->getLoc(), outShapeVec);
|
|
|
|
|
auto resultA = rewriter.create<stablehlo::DynamicReshapeOp>(
|
|
|
|
|
loc, getTypeConverter()->convertType(op.getType(0)),
|
|
|
|
|
stablehloReduceOp.getResult(0), outShapeTensor);
|
|
|
|
|
|
2024-05-31 14:45:13 +08:00
|
|
|
|
RankedTensorType resultType = cast<RankedTensorType>(
|
|
|
|
|
getTypeConverter()->convertType(op->getResult(1).getType()));
|
2023-09-05 21:28:37 +08:00
|
|
|
|
Value resultB =
|
|
|
|
|
createInitialValueForGatherScatterOp(op, resultType, rewriter);
|
|
|
|
|
if (!resultB)
|
|
|
|
|
return failure();
|
|
|
|
|
|
2024-05-31 14:45:13 +08:00
|
|
|
|
resultType = cast<RankedTensorType>(
|
|
|
|
|
getTypeConverter()->convertType(op->getResult(2).getType()));
|
2023-09-05 21:28:37 +08:00
|
|
|
|
Value resultC =
|
|
|
|
|
createInitialValueForGatherScatterOp(op, resultType, rewriter);
|
|
|
|
|
if (!resultC)
|
|
|
|
|
return failure();
|
|
|
|
|
|
2024-05-31 14:45:13 +08:00
|
|
|
|
resultType = cast<RankedTensorType>(
|
|
|
|
|
getTypeConverter()->convertType(op->getResult(3).getType()));
|
2023-09-05 21:28:37 +08:00
|
|
|
|
Value resultD =
|
|
|
|
|
createInitialValueForGatherScatterOp(op, resultType, rewriter);
|
|
|
|
|
if (!resultD)
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, {resultA, resultB, resultC, resultD});
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
2022-07-25 23:47:46 +08:00
|
|
|
|
template <>
|
|
|
|
|
LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
|
|
|
|
AtenIndexSelectOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
auto self = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
2022-07-25 23:47:46 +08:00
|
|
|
|
if (!selfTy)
|
|
|
|
|
return op.emitError("only ranked tensor types are supported");
|
|
|
|
|
int64_t dim;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
2022-07-25 23:47:46 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant dim is currently supported");
|
2023-04-07 19:49:35 +08:00
|
|
|
|
int64_t inputRank = selfTy.getRank();
|
|
|
|
|
dim = toPositiveDim(dim, inputRank);
|
|
|
|
|
if (!isValidDim(dim, inputRank))
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
2022-07-25 23:47:46 +08:00
|
|
|
|
|
2022-09-01 10:36:02 +08:00
|
|
|
|
Value output = gatherTensorAlongSingleAxis(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits);
|
2022-07-25 23:47:46 +08:00
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
|
2022-07-25 23:47:46 +08:00
|
|
|
|
op, getTypeConverter()->convertType(op.getType()), output);
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
2022-09-25 22:07:46 +08:00
|
|
|
|
// AtenGatherOp
|
|
|
|
|
template <>
|
|
|
|
|
LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|
|
|
|
AtenGatherOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
Location loc = op->getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = adaptor.getSelf();
|
|
|
|
|
Value index = adaptor.getIndex();
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto inputType = cast<RankedTensorType>(input.getType());
|
|
|
|
|
auto indexType = cast<RankedTensorType>(index.getType());
|
2022-09-25 22:07:46 +08:00
|
|
|
|
auto indexElemType = indexType.getElementType();
|
|
|
|
|
|
|
|
|
|
if (indexType.getRank() != inputType.getRank()) {
|
|
|
|
|
return op.emitError("`index` and `input` param should have the same rank");
|
|
|
|
|
}
|
|
|
|
|
int64_t dim;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
|
2022-09-25 22:07:46 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant int `dim` param supported");
|
|
|
|
|
}
|
|
|
|
|
dim = toPositiveDim(dim, inputType.getRank());
|
|
|
|
|
if (!isValidDim(dim, inputType.getRank())) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "invalid `dim` param detected");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool sparseGrad = false;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getSparseGrad(), m_TorchConstantBool(&sparseGrad))) {
|
2022-09-25 22:07:46 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant boolean `sparse_grad` param supported");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto options = getOptions();
|
|
|
|
|
auto indexShapeInfo =
|
2023-02-02 21:29:47 +08:00
|
|
|
|
hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
|
2022-09-25 22:07:46 +08:00
|
|
|
|
if (failed(indexShapeInfo)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "failed to get dim sizes of `index` param");
|
|
|
|
|
}
|
|
|
|
|
auto intType = rewriter.getIntegerType(options.dimSizeIndexBits);
|
|
|
|
|
auto one = rewriter.create<arith::ConstantOp>(
|
|
|
|
|
loc, rewriter.getIntegerAttr(intType, 1));
|
|
|
|
|
auto toConcatIndexShapeValueVec = *indexShapeInfo;
|
|
|
|
|
toConcatIndexShapeValueVec.push_back(one);
|
|
|
|
|
auto toConcatIndexShape =
|
|
|
|
|
rewriter.create<tensor::FromElementsOp>(loc, toConcatIndexShapeValueVec);
|
|
|
|
|
|
|
|
|
|
auto indexShape = indexType.getShape();
|
|
|
|
|
SmallVector<int64_t> toConcatIndexShapeVec(indexShape.begin(),
|
|
|
|
|
indexShape.end());
|
|
|
|
|
toConcatIndexShapeVec.push_back(1);
|
|
|
|
|
RankedTensorType toConcatIndexType =
|
|
|
|
|
RankedTensorType::get(toConcatIndexShapeVec, indexElemType);
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> toConcat;
|
|
|
|
|
for (int64_t i = 0; i < inputType.getRank(); ++i) {
|
|
|
|
|
if (i == dim) {
|
2023-02-02 21:29:47 +08:00
|
|
|
|
toConcat.push_back(rewriter.create<stablehlo::DynamicReshapeOp>(
|
2022-09-25 22:07:46 +08:00
|
|
|
|
loc, toConcatIndexType, index, toConcatIndexShape));
|
|
|
|
|
} else {
|
2023-02-02 21:29:47 +08:00
|
|
|
|
toConcat.push_back(rewriter.create<stablehlo::DynamicIotaOp>(
|
2022-09-25 22:07:46 +08:00
|
|
|
|
loc, toConcatIndexType, toConcatIndexShape,
|
|
|
|
|
rewriter.getI64IntegerAttr(i)));
|
|
|
|
|
}
|
|
|
|
|
}
|
2023-02-02 21:29:47 +08:00
|
|
|
|
auto gatherIndicies = rewriter.create<stablehlo::ConcatenateOp>(
|
2022-09-25 22:07:46 +08:00
|
|
|
|
loc, toConcat, static_cast<uint64_t>(inputType.getRank()));
|
|
|
|
|
SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);
|
|
|
|
|
|
|
|
|
|
int64_t indexVecDim = inputType.getRank();
|
|
|
|
|
SmallVector<int64_t> collapsedDims;
|
|
|
|
|
SmallVector<int64_t> startIndexMap;
|
|
|
|
|
for (int64_t i = 0; i < inputType.getRank(); ++i) {
|
|
|
|
|
collapsedDims.push_back(i);
|
|
|
|
|
startIndexMap.push_back(i);
|
|
|
|
|
}
|
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
|
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
|
2022-09-25 22:07:46 +08:00
|
|
|
|
rewriter.getContext(),
|
|
|
|
|
/*offsetDims=*/{},
|
|
|
|
|
/*collapsedSliceDims=*/collapsedDims,
|
2024-05-22 23:28:45 +08:00
|
|
|
|
/*operandBatchingDims=*/{},
|
|
|
|
|
/*startIndicesBatchingDims=*/{},
|
2022-09-25 22:07:46 +08:00
|
|
|
|
/*startIndexMap=*/startIndexMap,
|
|
|
|
|
/*indexVecDim=*/indexVecDim);
|
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
|
2022-09-25 22:07:46 +08:00
|
|
|
|
op, input, gatherIndicies, dimsAttr,
|
Bump stablehlo to openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 (#2821)
With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
rewriter.startRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
rewriter.finalizeRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
rewriter.cancelRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```
I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it.
It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test
...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
^
LLVM ERROR: Failed to infer result type(s).
```
Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
2024-02-01 06:21:17 +08:00
|
|
|
|
rewriter.getDenseI64ArrayAttr(sliceSizes));
|
2022-09-25 22:07:46 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
2023-03-23 04:41:04 +08:00
|
|
|
|
// AtenSliceScatterOp
|
|
|
|
|
template <>
|
|
|
|
|
LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
|
|
|
|
|
AtenSliceScatterOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
|
|
|
|
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
2023-08-16 00:53:28 +08:00
|
|
|
|
const TypeConverter *typeConverter = getTypeConverter();
|
2023-03-23 04:41:04 +08:00
|
|
|
|
|
|
|
|
|
auto input = adaptor.getSelf();
|
|
|
|
|
|
2024-05-31 14:45:13 +08:00
|
|
|
|
RankedTensorType resultType = cast<RankedTensorType>(
|
|
|
|
|
typeConverter->convertType(op->getResult(0).getType()));
|
2023-03-23 04:41:04 +08:00
|
|
|
|
|
|
|
|
|
SmallVector<Value> resultShape;
|
|
|
|
|
SmallVector<Value> offsets;
|
|
|
|
|
SmallVector<Value> strides;
|
|
|
|
|
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
|
|
|
|
|
AtenSliceScatterOpAdaptor>(
|
|
|
|
|
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value src = adaptor.getSrc();
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto srcType = cast<RankedTensorType>(src.getType());
|
2023-03-23 04:41:04 +08:00
|
|
|
|
int64_t srcRank = srcType.getRank();
|
|
|
|
|
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
|
|
|
|
auto abstractSrcType = RankedTensorType::get(
|
|
|
|
|
makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
|
|
|
|
|
Value abstractSrc =
|
|
|
|
|
rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
|
|
|
|
|
|
|
|
|
|
Value result = rewriter.create<tensor::InsertSliceOp>(
|
|
|
|
|
loc, abstractSrc, input, offsets, resultShape, strides);
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
2024-04-01 19:39:49 +08:00
|
|
|
|
template <typename AtenOpT, int reduceType>
|
|
|
|
|
class ConvertAtenScatterOp : public ConvertAtenOp<AtenOpT> {
|
|
|
|
|
public:
|
|
|
|
|
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
|
|
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
Value input = adaptor.getSelf();
|
|
|
|
|
Value index = adaptor.getIndex();
|
|
|
|
|
Value src = adaptor.getSrc();
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto inputType = cast<RankedTensorType>(input.getType());
|
|
|
|
|
auto indexType = cast<RankedTensorType>(index.getType());
|
|
|
|
|
auto srcType = cast<RankedTensorType>(src.getType());
|
2024-04-01 19:39:49 +08:00
|
|
|
|
auto indexElemType = indexType.getElementType();
|
|
|
|
|
|
|
|
|
|
if (indexType.getRank() != inputType.getRank() ||
|
|
|
|
|
inputType.getRank() != srcType.getRank()) {
|
|
|
|
|
return op.emitError(
|
|
|
|
|
"`index`, `input` and `src` param should have the same rank");
|
|
|
|
|
}
|
|
|
|
|
int64_t dim;
|
|
|
|
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant int `dim` param supported");
|
|
|
|
|
}
|
|
|
|
|
dim = toPositiveDim(dim, inputType.getRank());
|
|
|
|
|
if (!isValidDim(dim, inputType.getRank())) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "invalid `dim` param detected");
|
2023-07-24 10:14:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
2024-04-01 19:39:49 +08:00
|
|
|
|
auto options = this->getOptions();
|
2023-07-24 10:14:45 +08:00
|
|
|
|
|
2024-04-01 19:39:49 +08:00
|
|
|
|
auto indexShapeInfo =
|
|
|
|
|
hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
|
|
|
|
|
if (failed(indexShapeInfo)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "failed to get dim sizes of `index` param");
|
|
|
|
|
}
|
|
|
|
|
auto intType = rewriter.getIntegerType(options.dimSizeIndexBits);
|
|
|
|
|
|
|
|
|
|
// slice src tensor to have the same shape bound of index tensor in the
|
|
|
|
|
// leading dimensions. PyTorch has guaranteed that src tensor size will not
|
|
|
|
|
// be smaller than that of index tensor. REF:
|
|
|
|
|
// https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_
|
|
|
|
|
auto zero = rewriter.create<arith::ConstantOp>(
|
|
|
|
|
loc, rewriter.getIntegerAttr(intType, 0));
|
|
|
|
|
auto one = rewriter.create<arith::ConstantOp>(
|
|
|
|
|
loc, rewriter.getIntegerAttr(intType, 1));
|
|
|
|
|
SmallVector<Value> sliceIndicies(srcType.getRank(), zero);
|
|
|
|
|
SmallVector<Value> sliceStrides(srcType.getRank(), one);
|
|
|
|
|
|
|
|
|
|
auto sliceIndiciesValue =
|
|
|
|
|
rewriter.create<tensor::FromElementsOp>(loc, sliceIndicies);
|
|
|
|
|
auto sliceStridesValue =
|
|
|
|
|
rewriter.create<tensor::FromElementsOp>(loc, sliceStrides);
|
|
|
|
|
auto sliceLimitIndiciesValue =
|
|
|
|
|
rewriter.create<tensor::FromElementsOp>(loc, *indexShapeInfo);
|
|
|
|
|
|
|
|
|
|
auto newSrcType =
|
|
|
|
|
RankedTensorType::get(indexType.getShape(), srcType.getElementType());
|
|
|
|
|
src = rewriter.create<stablehlo::RealDynamicSliceOp>(
|
|
|
|
|
loc, newSrcType, src, sliceIndiciesValue, sliceLimitIndiciesValue,
|
|
|
|
|
sliceStridesValue);
|
|
|
|
|
|
|
|
|
|
// generate scatter indicies for stablehlo::Scatter op.
|
|
|
|
|
auto toConcatIndexShapeValueVec = *indexShapeInfo;
|
|
|
|
|
toConcatIndexShapeValueVec.push_back(one);
|
|
|
|
|
auto toConcatIndexShape = rewriter.create<tensor::FromElementsOp>(
|
|
|
|
|
loc, toConcatIndexShapeValueVec);
|
|
|
|
|
|
|
|
|
|
auto indexShape = indexType.getShape();
|
|
|
|
|
SmallVector<int64_t> toConcatIndexShapeVec(indexShape.begin(),
|
|
|
|
|
indexShape.end());
|
|
|
|
|
toConcatIndexShapeVec.push_back(1);
|
|
|
|
|
RankedTensorType toConcatIndexType =
|
|
|
|
|
RankedTensorType::get(toConcatIndexShapeVec, indexElemType);
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> toConcat;
|
|
|
|
|
for (int64_t i = 0; i < inputType.getRank(); ++i) {
|
|
|
|
|
if (i == dim) {
|
|
|
|
|
toConcat.push_back(rewriter.create<stablehlo::DynamicReshapeOp>(
|
|
|
|
|
loc, toConcatIndexType, index, toConcatIndexShape));
|
|
|
|
|
} else {
|
|
|
|
|
toConcat.push_back(rewriter.create<stablehlo::DynamicIotaOp>(
|
|
|
|
|
loc, toConcatIndexType, toConcatIndexShape,
|
|
|
|
|
rewriter.getI64IntegerAttr(i)));
|
|
|
|
|
}
|
|
|
|
|
}
|
2023-07-24 10:14:45 +08:00
|
|
|
|
|
2024-04-01 19:39:49 +08:00
|
|
|
|
auto scatterIndicies = rewriter.create<stablehlo::ConcatenateOp>(
|
|
|
|
|
loc, toConcat, static_cast<uint64_t>(inputType.getRank()));
|
|
|
|
|
SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);
|
|
|
|
|
|
|
|
|
|
// generate ScatterDimensionNumbers for stablehlo::Scatter op.
|
|
|
|
|
int64_t indexVecDim = inputType.getRank();
|
|
|
|
|
SmallVector<int64_t> scatterDimOperandDimMap;
|
|
|
|
|
SmallVector<int64_t> insertedWindowDims;
|
|
|
|
|
for (int64_t i = 0; i < inputType.getRank(); ++i) {
|
|
|
|
|
scatterDimOperandDimMap.push_back(i);
|
|
|
|
|
insertedWindowDims.push_back(i);
|
|
|
|
|
}
|
|
|
|
|
auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get(
|
|
|
|
|
rewriter.getContext(),
|
|
|
|
|
/*updateWindowDims=*/{},
|
|
|
|
|
/*insertedWindowDims=*/insertedWindowDims,
|
2024-05-22 23:28:45 +08:00
|
|
|
|
/*inputBatchingDims=*/{},
|
|
|
|
|
/*scatterIndicesBatchingDims=*/{},
|
2024-04-01 19:39:49 +08:00
|
|
|
|
/*scatterDimsToOperandDim=*/scatterDimOperandDimMap,
|
|
|
|
|
/*indexVectorDim=*/indexVecDim);
|
|
|
|
|
|
|
|
|
|
auto stablehloScatterOp = rewriter.create<stablehlo::ScatterOp>(
|
|
|
|
|
loc, inputType, input, scatterIndicies, src, scatterDimensionNumbers,
|
|
|
|
|
false, false);
|
|
|
|
|
|
|
|
|
|
// config update computation function: just return the element from src.
|
|
|
|
|
Block &block = stablehloScatterOp.getUpdateComputation().emplaceBlock();
|
|
|
|
|
// add block arguments
|
|
|
|
|
auto blockArgumentType =
|
|
|
|
|
RankedTensorType::get({}, inputType.getElementType());
|
|
|
|
|
block.addArgument(blockArgumentType, loc);
|
|
|
|
|
block.addArgument(blockArgumentType, loc);
|
|
|
|
|
|
|
|
|
|
auto *lhsArg = block.args_begin();
|
|
|
|
|
auto *rhsArg = std::next(lhsArg);
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
|
|
|
rewriter.setInsertionPointToStart(&block);
|
|
|
|
|
if (reduceType == 0) {
|
|
|
|
|
rewriter.create<stablehlo::ReturnOp>(loc, *rhsArg);
|
|
|
|
|
} else if (reduceType == 1) {
|
|
|
|
|
Value res = rewriter.create<stablehlo::AddOp>(loc, blockArgumentType,
|
|
|
|
|
*lhsArg, *rhsArg);
|
|
|
|
|
rewriter.create<stablehlo::ReturnOp>(loc, res);
|
|
|
|
|
}
|
|
|
|
|
}
|
2023-07-24 10:14:45 +08:00
|
|
|
|
|
2024-04-01 19:39:49 +08:00
|
|
|
|
rewriter.replaceOp(op, stablehloScatterOp.getResults());
|
|
|
|
|
return success();
|
2023-07-24 10:14:45 +08:00
|
|
|
|
}
|
2024-04-01 19:39:49 +08:00
|
|
|
|
};
|
2023-07-24 10:14:45 +08:00
|
|
|
|
|
2023-05-25 02:13:57 +08:00
|
|
|
|
// AtenIndexTensorOp
|
2024-04-01 19:39:49 +08:00
|
|
|
|
// Convert to StableHlo::GatherOp.
|
2023-05-25 02:13:57 +08:00
|
|
|
|
template <>
|
2023-08-15 19:36:08 +08:00
|
|
|
|
LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
|
|
|
|
AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
|
2023-05-25 02:13:57 +08:00
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
Value input = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto inputTensorType = cast<RankedTensorType>(input.getType());
|
2024-04-01 19:39:49 +08:00
|
|
|
|
auto outType =
|
2024-04-28 05:00:56 +08:00
|
|
|
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
2024-04-01 19:39:49 +08:00
|
|
|
|
auto outShape = outType.getShape();
|
2023-05-25 02:13:57 +08:00
|
|
|
|
Value indexList = op.getIndices();
|
|
|
|
|
SmallVector<Value> indicesTorchType;
|
|
|
|
|
if (!getListConstructElements(indexList, indicesTorchType))
|
|
|
|
|
return op.emitError(
|
|
|
|
|
"unimplemented: the tensor list is not from list construct");
|
|
|
|
|
|
|
|
|
|
auto indexTensors = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
|
|
|
|
|
indicesTorchType);
|
|
|
|
|
|
2024-04-01 19:39:49 +08:00
|
|
|
|
int maxIndexRank = -1;
|
|
|
|
|
auto gatherIndicesInfo = broadcastAndConcatIndices(op, rewriter, indexTensors,
|
|
|
|
|
outShape, maxIndexRank);
|
|
|
|
|
if (failed(gatherIndicesInfo)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "failed to generate broadcasted indices");
|
2023-05-25 02:13:57 +08:00
|
|
|
|
}
|
2024-04-01 19:39:49 +08:00
|
|
|
|
auto gatherIndices = *gatherIndicesInfo;
|
2023-05-25 02:13:57 +08:00
|
|
|
|
|
2024-04-01 19:39:49 +08:00
|
|
|
|
int64_t numIndicesDim = indexTensors.size();
|
|
|
|
|
int64_t indexVecDim = maxIndexRank;
|
2023-05-25 02:13:57 +08:00
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> offsetDims;
|
|
|
|
|
SmallVector<int64_t> collapsedDims;
|
|
|
|
|
SmallVector<int64_t> startIndexMap;
|
|
|
|
|
for (int64_t i = 0; i < numIndicesDim; ++i) {
|
|
|
|
|
collapsedDims.push_back(i);
|
|
|
|
|
startIndexMap.push_back(i);
|
|
|
|
|
}
|
|
|
|
|
for (int64_t i = numIndicesDim; i < inputTensorType.getRank(); i++) {
|
2024-04-01 19:39:49 +08:00
|
|
|
|
offsetDims.push_back(i + maxIndexRank - numIndicesDim);
|
2023-05-25 02:13:57 +08:00
|
|
|
|
}
|
|
|
|
|
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
|
|
|
|
|
rewriter.getContext(),
|
|
|
|
|
/*offsetDims=*/offsetDims,
|
|
|
|
|
/*collapsedSliceDims=*/collapsedDims,
|
2024-05-22 23:28:45 +08:00
|
|
|
|
/*operandBatchingDims=*/{},
|
|
|
|
|
/*startIndicesBatchingDims=*/{},
|
2023-05-25 02:13:57 +08:00
|
|
|
|
/*startIndexMap=*/startIndexMap,
|
|
|
|
|
/*indexVecDim=*/indexVecDim);
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> sliceSizes;
|
|
|
|
|
auto inputShape = makeShapeTorchCompatible(inputTensorType.getShape());
|
|
|
|
|
for (int64_t i = 0; i < inputTensorType.getRank(); ++i) {
|
|
|
|
|
if (i < numIndicesDim) {
|
|
|
|
|
sliceSizes.push_back(1);
|
|
|
|
|
} else {
|
|
|
|
|
sliceSizes.push_back(inputShape[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
|
2024-04-01 19:39:49 +08:00
|
|
|
|
op, outType, input, gatherIndices, dimsAttr,
|
Bump stablehlo to openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 (#2821)
With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
rewriter.startRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
rewriter.finalizeRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
rewriter.cancelRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```
I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it.
It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test
...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
^
LLVM ERROR: Failed to infer result type(s).
```
Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
2024-02-01 06:21:17 +08:00
|
|
|
|
rewriter.getDenseI64ArrayAttr(sliceSizes));
|
2023-05-25 02:13:57 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
2024-04-01 19:39:49 +08:00
|
|
|
|
// AtenIndexPutHackedTwinOP
|
|
|
|
|
// Convert to stablehlo::ScatterOp
|
|
|
|
|
template <>
|
|
|
|
|
LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
|
|
|
|
AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
Value input = adaptor.getSelf();
|
|
|
|
|
Value values = adaptor.getValues();
|
|
|
|
|
auto outType =
|
2024-04-28 05:00:56 +08:00
|
|
|
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
|
|
|
|
auto inputType = cast<RankedTensorType>(input.getType());
|
2024-04-01 19:39:49 +08:00
|
|
|
|
int64_t inputRank = inputType.getRank();
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto valuesType = cast<RankedTensorType>(values.getType());
|
2024-04-01 19:39:49 +08:00
|
|
|
|
auto valuesShape = valuesType.getShape();
|
|
|
|
|
bool accumulate;
|
|
|
|
|
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"accumulate should be a constant bool");
|
|
|
|
|
}
|
|
|
|
|
Value indexList = op.getIndices();
|
|
|
|
|
SmallVector<Value> indicesTorchType;
|
|
|
|
|
if (!getListConstructElements(indexList, indicesTorchType))
|
|
|
|
|
return op.emitError(
|
|
|
|
|
"unimplemented: the tensor list is not from list construct");
|
|
|
|
|
|
|
|
|
|
auto indexTensors = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
|
|
|
|
|
indicesTorchType);
|
|
|
|
|
|
|
|
|
|
int maxIndexRank = -1;
|
|
|
|
|
auto scatterIndicesInfo = broadcastAndConcatIndices(
|
|
|
|
|
op, rewriter, indexTensors, valuesShape, maxIndexRank);
|
|
|
|
|
if (failed(scatterIndicesInfo)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "failed to generate broadcasted indices");
|
|
|
|
|
}
|
|
|
|
|
auto scatterIndices = *scatterIndicesInfo;
|
|
|
|
|
|
|
|
|
|
// create stablehlo::ScatterOp
|
|
|
|
|
int64_t indexVecDim = maxIndexRank;
|
|
|
|
|
SmallVector<int64_t> scatterDimOperandDimMap;
|
|
|
|
|
SmallVector<int64_t> insertedWindowDims;
|
|
|
|
|
SmallVector<int64_t> updateWindowDims;
|
|
|
|
|
for (int64_t i = 0; i < maxIndexRank; ++i) {
|
|
|
|
|
scatterDimOperandDimMap.push_back(i);
|
|
|
|
|
insertedWindowDims.push_back(i);
|
|
|
|
|
}
|
|
|
|
|
for (int64_t i = maxIndexRank; i < inputRank; ++i) {
|
|
|
|
|
updateWindowDims.push_back(i);
|
|
|
|
|
}
|
|
|
|
|
llvm::outs() << "maxIndexRank: " << maxIndexRank << "\n";
|
|
|
|
|
auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get(
|
|
|
|
|
rewriter.getContext(),
|
|
|
|
|
/*updateWindowDims=*/updateWindowDims,
|
|
|
|
|
/*insertedWindowDims=*/insertedWindowDims,
|
2024-05-22 23:28:45 +08:00
|
|
|
|
/*inputBatchingDims=*/{},
|
|
|
|
|
/*scatterIndicesBatchingDims=*/{},
|
2024-04-01 19:39:49 +08:00
|
|
|
|
/*scatterDimsToOperandDim=*/scatterDimOperandDimMap,
|
|
|
|
|
/*indexVectorDim=*/indexVecDim);
|
|
|
|
|
|
|
|
|
|
auto stablehloScatterOp = rewriter.create<stablehlo::ScatterOp>(
|
|
|
|
|
loc, outType, input, scatterIndices, values, scatterDimensionNumbers,
|
|
|
|
|
false, false);
|
|
|
|
|
|
|
|
|
|
// configure update computation function.
|
|
|
|
|
Block &block = stablehloScatterOp.getUpdateComputation().emplaceBlock();
|
|
|
|
|
// add block arguments
|
|
|
|
|
auto blockArgumentType =
|
|
|
|
|
RankedTensorType::get({}, inputType.getElementType());
|
|
|
|
|
block.addArgument(blockArgumentType, loc);
|
|
|
|
|
block.addArgument(blockArgumentType, loc);
|
|
|
|
|
|
|
|
|
|
auto *lhsArg = block.args_begin();
|
|
|
|
|
auto *rhsArg = std::next(lhsArg);
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
|
|
|
rewriter.setInsertionPointToStart(&block);
|
|
|
|
|
if (!accumulate) {
|
|
|
|
|
rewriter.create<stablehlo::ReturnOp>(loc, *rhsArg);
|
|
|
|
|
} else {
|
|
|
|
|
Value out = rewriter.create<stablehlo::AddOp>(loc, blockArgumentType,
|
|
|
|
|
*lhsArg, *rhsArg);
|
|
|
|
|
rewriter.create<stablehlo::ReturnOp>(loc, out);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, stablehloScatterOp.getResults());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
2023-03-23 04:41:04 +08:00
|
|
|
|
void mlir::torch::torch_to_stablehlo::
|
|
|
|
|
populateGatherScatterOpPatternsAndLegality(
|
|
|
|
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
|
|
|
|
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
2022-07-25 23:47:46 +08:00
|
|
|
|
MLIRContext *context = patterns.getContext();
|
|
|
|
|
|
|
|
|
|
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
|
|
|
|
target.addIllegalOp<AtenOp>(); \
|
2022-09-01 10:36:02 +08:00
|
|
|
|
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
|
2022-07-25 23:47:46 +08:00
|
|
|
|
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
|
2023-09-05 21:28:37 +08:00
|
|
|
|
INSERT_ATENOP_PATTERN(AtenEmbeddingBagPaddingIdxOp);
|
2022-07-25 23:47:46 +08:00
|
|
|
|
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
2022-09-25 22:07:46 +08:00
|
|
|
|
INSERT_ATENOP_PATTERN(AtenGatherOp);
|
2023-03-23 04:41:04 +08:00
|
|
|
|
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
|
2023-08-15 19:36:08 +08:00
|
|
|
|
INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
|
2024-04-01 19:39:49 +08:00
|
|
|
|
INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp);
|
2022-07-25 23:47:46 +08:00
|
|
|
|
#undef INSERT_ATENOP_PATTERN
|
2024-04-01 19:39:49 +08:00
|
|
|
|
|
|
|
|
|
#define INSERT_ATEN_SCATTER_PATTERN(AtenOp, reduceType) \
|
|
|
|
|
target.addIllegalOp<AtenOp>(); \
|
|
|
|
|
patterns.add<ConvertAtenScatterOp<AtenOp, reduceType>>(typeConverter, \
|
|
|
|
|
context, options)
|
|
|
|
|
INSERT_ATEN_SCATTER_PATTERN(AtenScatterSrcOp, 0); // 0 for None reduce op
|
|
|
|
|
INSERT_ATEN_SCATTER_PATTERN(AtenScatterAddOp, 1); // 1 for Add reduce op
|
|
|
|
|
#undef INSERT_ATEN_SCATTER_PATTERN
|
2022-07-25 23:47:46 +08:00
|
|
|
|
}
|