[Stablehlo] use stablehlo specs lowering AtenSliceScatterOp (#3592)

pull/3638/head
yyp0 2024-08-15 20:06:29 +08:00 committed by GitHub
parent 64b0d4aed3
commit 43e3118eb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 85 additions and 24 deletions

View File

@ -630,32 +630,100 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
const TypeConverter *typeConverter = getTypeConverter();
auto input = adaptor.getSelf();
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));
SmallVector<Value> resultShape;
SmallVector<Value> offsets;
SmallVector<Value> strides;
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
AtenSliceScatterOpAdaptor>(
op, adaptor, rewriter, resultShape, offsets, strides))) {
return failure();
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
return op->emitError("unimplemented: dim is not constant");
}
int64_t inputRank = inputType.getRank();
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank)) {
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
}
auto inputShape = inputType.getShape();
auto dimSize = inputShape[dim];
int64_t step;
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
return op->emitError("unimplemented: step is not constant");
}
int64_t start;
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) {
return op->emitError("unimplemented: start is not constant");
} else if (ShapedType::isDynamic(dimSize) and start < 0) {
return op->emitError("unimplemented: not support dynamic dimSize when "
"start smaller than 0.");
}
start = start >= 0 ? start : dimSize + start;
int64_t end;
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
return op->emitError("unimplemented: end is not constant");
} else if (ShapedType::isDynamic(dimSize) and end < 0) {
return op->emitError(
"unimplemented: not support dynamic dimSize when end smaller than 0.");
}
end = end >= 0 ? end : dimSize + end;
int64_t size = 0;
std::vector<int64_t> indicesVec;
for (int64_t i = start; i < end; i += step) {
indicesVec.push_back(i);
++size;
}
ArrayRef<int64_t> indices(indicesVec);
std::vector<int64_t> tmp_shape = {size, 1};
ArrayRef<int64_t> shape(tmp_shape);
RankedTensorType constType =
RankedTensorType::get(shape, rewriter.getIntegerType(64));
auto constAttr = DenseElementsAttr::get(
RankedTensorType::get(shape, rewriter.getIntegerType(64)), indices);
auto const_op =
rewriter.create<stablehlo::ConstantOp>(loc, constType, constAttr);
Value scatterIndices = const_op.getResult();
SmallVector<int64_t> updateWindowDims;
for (int64_t i = 0; i < inputType.getRank(); ++i) {
if (i == dim) {
continue;
}
updateWindowDims.push_back(i);
}
auto scatterArgs = stablehlo::ScatterDimensionNumbersAttr::get(
rewriter.getContext(),
/*updateWindowDims=*/updateWindowDims,
/*insertedWindowDims=*/{dim},
/*inputBatchingDims=*/{},
/*scatterIndicesBatchingDims=*/{},
/*scatterDimsToOperandDim=*/{dim},
/*indexVectorDim=*/1);
Value src = adaptor.getSrc();
auto srcType = cast<RankedTensorType>(src.getType());
int64_t srcRank = srcType.getRank();
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
auto abstractSrcType = RankedTensorType::get(
makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
Value abstractSrc =
rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
auto scatterOp = rewriter.create<stablehlo::ScatterOp>(
loc, resultType, input, scatterIndices, src, scatterArgs, false, false);
Value result = rewriter.create<tensor::InsertSliceOp>(
loc, abstractSrc, input, offsets, resultShape, strides);
Block &block = scatterOp.getUpdateComputation().emplaceBlock();
auto blockArgumentType =
RankedTensorType::get({}, inputType.getElementType());
block.addArgument(blockArgumentType, loc);
block.addArgument(blockArgumentType, loc);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
auto *lhs = block.args_begin();
auto *rhs = std::next(lhs);
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
rewriter.create<stablehlo::ReturnOp>(loc, *rhs);
}
rewriter.replaceOp(op, scatterOp.getResults());
return success();
}

View File

@ -1328,12 +1328,6 @@ STABLEHLO_PASS_SET = {
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceOutOfUpperBoundIndexModule_basic",
"SliceOutOfUpperBoundIndexStaticModule_basic",
"SliceScatterModule_basic",
"SliceScatterNegativeDimModule_basic",
"SliceScatterNegativeEndModule_basic",
"SliceScatterStaticModule_basic",
"SliceScatterStepVariationModule_basic",
"SliceScatterZeroDimModule_basic",
"SliceSizeTwoStepModule_basic",
"SliceStartEqEndModule_basic",
"SliceStaticModule_basic",
@ -1464,7 +1458,6 @@ STABLEHLO_PASS_SET = {
"RandModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SelectScattertStaticModule_basic",
"SqueezeDimModule_static",
"SqueezeModule_static",
"TriuBroadcastModule_basic",