mirror of https://github.com/llvm/torch-mlir
[Stablehlo] use stablehlo specs lowering AtenSliceScatterOp (#3592)
parent 64b0d4aed3
commit 43e3118eb9
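For context: `aten.slice_scatter` embeds a `src` tensor into a strided slice of `self` along one dimension and returns the combined tensor; this commit switches its StableHLO lowering from a `tensor.insert_slice`-based path to a `stablehlo.scatter`-based one. A minimal PyTorch sketch of the op's semantics (shapes and slice parameters below are illustrative, not taken from the commit):

import torch

# slice_scatter returns `self` with self[start:end:step] along `dim`
# replaced by `src`; every other element is left untouched.
self_t = torch.zeros(6, 4)
src = torch.ones(2, 4)

out = torch.slice_scatter(self_t, src, dim=0, start=1, end=5, step=2)
# Rows 1 and 3 now come from `src`; the rest is still `self_t`.
assert torch.equal(out[1], src[0]) and torch.equal(out[3], src[1])
assert torch.equal(out[0], self_t[0])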
@@ -630,32 +630,100 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
   const TypeConverter *typeConverter = getTypeConverter();
 
   auto input = adaptor.getSelf();
+  RankedTensorType inputType = cast<RankedTensorType>(input.getType());
 
   RankedTensorType resultType = cast<RankedTensorType>(
       typeConverter->convertType(op->getResult(0).getType()));
 
-  SmallVector<Value> resultShape;
-  SmallVector<Value> offsets;
-  SmallVector<Value> strides;
-  if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
-                                          AtenSliceScatterOpAdaptor>(
-          op, adaptor, rewriter, resultShape, offsets, strides))) {
-    return failure();
+  int64_t dim;
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
+    return op->emitError("unimplemented: dim is not constant");
   }
 
+  int64_t inputRank = inputType.getRank();
+  dim = toPositiveDim(dim, inputRank);
+  if (!isValidDim(dim, inputRank)) {
+    return rewriter.notifyMatchFailure(op, "dim is statically invalid");
+  }
+
+  auto inputShape = inputType.getShape();
+  auto dimSize = inputShape[dim];
+  int64_t step;
+  if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
+    return op->emitError("unimplemented: step is not constant");
+  }
+
+  int64_t start;
+  if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) {
+    return op->emitError("unimplemented: start is not constant");
+  } else if (ShapedType::isDynamic(dimSize) and start < 0) {
+    return op->emitError("unimplemented: not support dynamic dimSize when "
+                         "start smaller than 0.");
+  }
+  start = start >= 0 ? start : dimSize + start;
+
+  int64_t end;
+  if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
+    return op->emitError("unimplemented: end is not constant");
+  } else if (ShapedType::isDynamic(dimSize) and end < 0) {
+    return op->emitError(
+        "unimplemented: not support dynamic dimSize when end smaller than 0.");
+  }
+  end = end >= 0 ? end : dimSize + end;
+
+  int64_t size = 0;
+  std::vector<int64_t> indicesVec;
+  for (int64_t i = start; i < end; i += step) {
+    indicesVec.push_back(i);
+    ++size;
+  }
+  ArrayRef<int64_t> indices(indicesVec);
+  std::vector<int64_t> tmp_shape = {size, 1};
+  ArrayRef<int64_t> shape(tmp_shape);
+  RankedTensorType constType =
+      RankedTensorType::get(shape, rewriter.getIntegerType(64));
+  auto constAttr = DenseElementsAttr::get(
+      RankedTensorType::get(shape, rewriter.getIntegerType(64)), indices);
+  auto const_op =
+      rewriter.create<stablehlo::ConstantOp>(loc, constType, constAttr);
+  Value scatterIndices = const_op.getResult();
+
+  SmallVector<int64_t> updateWindowDims;
+  for (int64_t i = 0; i < inputType.getRank(); ++i) {
+    if (i == dim) {
+      continue;
+    }
+    updateWindowDims.push_back(i);
+  }
+
+  auto scatterArgs = stablehlo::ScatterDimensionNumbersAttr::get(
+      rewriter.getContext(),
+      /*updateWindowDims=*/updateWindowDims,
+      /*insertedWindowDims=*/{dim},
+      /*inputBatchingDims=*/{},
+      /*scatterIndicesBatchingDims=*/{},
+      /*scatterDimsToOperandDim=*/{dim},
+      /*indexVectorDim=*/1);
+
   Value src = adaptor.getSrc();
   auto srcType = cast<RankedTensorType>(src.getType());
-  int64_t srcRank = srcType.getRank();
-  SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
-  auto abstractSrcType = RankedTensorType::get(
-      makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
-  Value abstractSrc =
-      rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
+  auto scatterOp = rewriter.create<stablehlo::ScatterOp>(
+      loc, resultType, input, scatterIndices, src, scatterArgs, false, false);
 
-  Value result = rewriter.create<tensor::InsertSliceOp>(
-      loc, abstractSrc, input, offsets, resultShape, strides);
+  Block &block = scatterOp.getUpdateComputation().emplaceBlock();
+  auto blockArgumentType =
+      RankedTensorType::get({}, inputType.getElementType());
+  block.addArgument(blockArgumentType, loc);
+  block.addArgument(blockArgumentType, loc);
 
-  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
+  auto *lhs = block.args_begin();
+  auto *rhs = std::next(lhs);
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+    rewriter.create<stablehlo::ReturnOp>(loc, *rhs);
+  }
+
+  rewriter.replaceOp(op, scatterOp.getResults());
 
   return success();
 }
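The pattern above folds the constant start/end/step into an explicit [size, 1] i64 constant of slice positions and feeds it to stablehlo.scatter, whose update region simply returns the update value (the rhs block argument). A small Python model of that index computation, checked against torch.slice_scatter; the slice_indices helper is illustrative and index_copy stands in for the scatter in this dim=0 case, neither is code from the commit:

import torch

def slice_indices(dim_size: int, start: int, end: int, step: int) -> torch.Tensor:
    # Mirrors the C++ above: rebase negative start/end on dim_size,
    # then enumerate the kept positions as a [size, 1] index tensor.
    start = start if start >= 0 else dim_size + start
    end = end if end >= 0 else dim_size + end
    return torch.tensor(list(range(start, end, step)), dtype=torch.int64).unsqueeze(1)

self_t = torch.arange(24.0).reshape(6, 4)
src = -torch.ones(2, 4)
idx = slice_indices(self_t.shape[0], start=-5, end=5, step=2)  # rows 1 and 3
expected = torch.slice_scatter(self_t, src, dim=0, start=-5, end=5, step=2)
scattered = self_t.index_copy(0, idx.squeeze(1), src)
assert torch.equal(scattered, expected)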
@@ -1328,12 +1328,6 @@ STABLEHLO_PASS_SET = {
     "SliceOutOfLowerBoundStartIndexModule_basic",
     "SliceOutOfUpperBoundIndexModule_basic",
     "SliceOutOfUpperBoundIndexStaticModule_basic",
-    "SliceScatterModule_basic",
-    "SliceScatterNegativeDimModule_basic",
-    "SliceScatterNegativeEndModule_basic",
-    "SliceScatterStaticModule_basic",
-    "SliceScatterStepVariationModule_basic",
-    "SliceScatterZeroDimModule_basic",
     "SliceSizeTwoStepModule_basic",
     "SliceStartEqEndModule_basic",
     "SliceStaticModule_basic",
@@ -1464,7 +1458,6 @@ STABLEHLO_PASS_SET = {
     "RandModule_basic",
     "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
     "SelectIntNegativeDimAndIndexStaticModule_basic",
-    "SelectScattertStaticModule_basic",
     "SqueezeDimModule_static",
     "SqueezeModule_static",
     "TriuBroadcastModule_basic",
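The remaining hunks edit STABLEHLO_PASS_SET in the e2e test configuration, removing the SliceScatter* entries and SelectScattertStaticModule_basic at this location. For reference, the kind of computation those modules exercise looks roughly like the following (a hypothetical stand-in; the real modules live in torch-mlir's e2e test suite and use their own shapes and parameters):

import torch

class SliceScatterLikeModule(torch.nn.Module):
    # Rough stand-in for the SliceScatter* e2e modules named above.
    def forward(self, x: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
        return torch.slice_scatter(x, src, dim=1, start=0, end=4, step=2)

m = SliceScatterLikeModule()
out = m(torch.rand(3, 8), torch.rand(3, 2))
print(out.shape)  # torch.Size([3, 8])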