mirror of https://github.com/llvm/torch-mlir
[Stablehlo] use stablehlo specs lowering AtenSliceScatterOp (#3592)
parent
64b0d4aed3
commit
43e3118eb9
|
@ -630,32 +630,100 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
|
||||||
const TypeConverter *typeConverter = getTypeConverter();
|
const TypeConverter *typeConverter = getTypeConverter();
|
||||||
|
|
||||||
auto input = adaptor.getSelf();
|
auto input = adaptor.getSelf();
|
||||||
|
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||||
|
|
||||||
RankedTensorType resultType = cast<RankedTensorType>(
|
RankedTensorType resultType = cast<RankedTensorType>(
|
||||||
typeConverter->convertType(op->getResult(0).getType()));
|
typeConverter->convertType(op->getResult(0).getType()));
|
||||||
|
|
||||||
SmallVector<Value> resultShape;
|
int64_t dim;
|
||||||
SmallVector<Value> offsets;
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
|
||||||
SmallVector<Value> strides;
|
return op->emitError("unimplemented: dim is not constant");
|
||||||
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
|
|
||||||
AtenSliceScatterOpAdaptor>(
|
|
||||||
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
|
||||||
return failure();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64_t inputRank = inputType.getRank();
|
||||||
|
dim = toPositiveDim(dim, inputRank);
|
||||||
|
if (!isValidDim(dim, inputRank)) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto inputShape = inputType.getShape();
|
||||||
|
auto dimSize = inputShape[dim];
|
||||||
|
int64_t step;
|
||||||
|
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
|
||||||
|
return op->emitError("unimplemented: step is not constant");
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t start;
|
||||||
|
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) {
|
||||||
|
return op->emitError("unimplemented: start is not constant");
|
||||||
|
} else if (ShapedType::isDynamic(dimSize) and start < 0) {
|
||||||
|
return op->emitError("unimplemented: not support dynamic dimSize when "
|
||||||
|
"start smaller than 0.");
|
||||||
|
}
|
||||||
|
start = start >= 0 ? start : dimSize + start;
|
||||||
|
|
||||||
|
int64_t end;
|
||||||
|
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
|
||||||
|
return op->emitError("unimplemented: end is not constant");
|
||||||
|
} else if (ShapedType::isDynamic(dimSize) and end < 0) {
|
||||||
|
return op->emitError(
|
||||||
|
"unimplemented: not support dynamic dimSize when end smaller than 0.");
|
||||||
|
}
|
||||||
|
end = end >= 0 ? end : dimSize + end;
|
||||||
|
|
||||||
|
int64_t size = 0;
|
||||||
|
std::vector<int64_t> indicesVec;
|
||||||
|
for (int64_t i = start; i < end; i += step) {
|
||||||
|
indicesVec.push_back(i);
|
||||||
|
++size;
|
||||||
|
}
|
||||||
|
ArrayRef<int64_t> indices(indicesVec);
|
||||||
|
std::vector<int64_t> tmp_shape = {size, 1};
|
||||||
|
ArrayRef<int64_t> shape(tmp_shape);
|
||||||
|
RankedTensorType constType =
|
||||||
|
RankedTensorType::get(shape, rewriter.getIntegerType(64));
|
||||||
|
auto constAttr = DenseElementsAttr::get(
|
||||||
|
RankedTensorType::get(shape, rewriter.getIntegerType(64)), indices);
|
||||||
|
auto const_op =
|
||||||
|
rewriter.create<stablehlo::ConstantOp>(loc, constType, constAttr);
|
||||||
|
Value scatterIndices = const_op.getResult();
|
||||||
|
|
||||||
|
SmallVector<int64_t> updateWindowDims;
|
||||||
|
for (int64_t i = 0; i < inputType.getRank(); ++i) {
|
||||||
|
if (i == dim) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
updateWindowDims.push_back(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto scatterArgs = stablehlo::ScatterDimensionNumbersAttr::get(
|
||||||
|
rewriter.getContext(),
|
||||||
|
/*updateWindowDims=*/updateWindowDims,
|
||||||
|
/*insertedWindowDims=*/{dim},
|
||||||
|
/*inputBatchingDims=*/{},
|
||||||
|
/*scatterIndicesBatchingDims=*/{},
|
||||||
|
/*scatterDimsToOperandDim=*/{dim},
|
||||||
|
/*indexVectorDim=*/1);
|
||||||
|
|
||||||
Value src = adaptor.getSrc();
|
Value src = adaptor.getSrc();
|
||||||
auto srcType = cast<RankedTensorType>(src.getType());
|
auto scatterOp = rewriter.create<stablehlo::ScatterOp>(
|
||||||
int64_t srcRank = srcType.getRank();
|
loc, resultType, input, scatterIndices, src, scatterArgs, false, false);
|
||||||
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
|
||||||
auto abstractSrcType = RankedTensorType::get(
|
|
||||||
makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
|
|
||||||
Value abstractSrc =
|
|
||||||
rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
|
|
||||||
|
|
||||||
Value result = rewriter.create<tensor::InsertSliceOp>(
|
Block &block = scatterOp.getUpdateComputation().emplaceBlock();
|
||||||
loc, abstractSrc, input, offsets, resultShape, strides);
|
auto blockArgumentType =
|
||||||
|
RankedTensorType::get({}, inputType.getElementType());
|
||||||
|
block.addArgument(blockArgumentType, loc);
|
||||||
|
block.addArgument(blockArgumentType, loc);
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
auto *lhs = block.args_begin();
|
||||||
|
auto *rhs = std::next(lhs);
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(&block);
|
||||||
|
rewriter.create<stablehlo::ReturnOp>(loc, *rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, scatterOp.getResults());
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1328,12 +1328,6 @@ STABLEHLO_PASS_SET = {
|
||||||
"SliceOutOfLowerBoundStartIndexModule_basic",
|
"SliceOutOfLowerBoundStartIndexModule_basic",
|
||||||
"SliceOutOfUpperBoundIndexModule_basic",
|
"SliceOutOfUpperBoundIndexModule_basic",
|
||||||
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
||||||
"SliceScatterModule_basic",
|
|
||||||
"SliceScatterNegativeDimModule_basic",
|
|
||||||
"SliceScatterNegativeEndModule_basic",
|
|
||||||
"SliceScatterStaticModule_basic",
|
|
||||||
"SliceScatterStepVariationModule_basic",
|
|
||||||
"SliceScatterZeroDimModule_basic",
|
|
||||||
"SliceSizeTwoStepModule_basic",
|
"SliceSizeTwoStepModule_basic",
|
||||||
"SliceStartEqEndModule_basic",
|
"SliceStartEqEndModule_basic",
|
||||||
"SliceStaticModule_basic",
|
"SliceStaticModule_basic",
|
||||||
|
@ -1464,7 +1458,6 @@ STABLEHLO_PASS_SET = {
|
||||||
"RandModule_basic",
|
"RandModule_basic",
|
||||||
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
|
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
|
||||||
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
||||||
"SelectScattertStaticModule_basic",
|
|
||||||
"SqueezeDimModule_static",
|
"SqueezeDimModule_static",
|
||||||
"SqueezeModule_static",
|
"SqueezeModule_static",
|
||||||
"TriuBroadcastModule_basic",
|
"TriuBroadcastModule_basic",
|
||||||
|
|
Loading…
Reference in New Issue