[StableHLO] Support for slice_scatter (#1960)

Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>
pull/1965/head snapshot-20230323.786
Zhekun Zhang 2023-03-22 13:41:04 -07:00 committed by GitHub
parent 544b5f232b
commit 5758a0bfbb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 172 additions and 31 deletions

View File

@ -334,6 +334,8 @@ STABLEHLO_PASS_SET = {
"RsubIntModule_basic",
"RsubIntModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
"SelectScattertModule_basic",
"SelectScattertStaticModule_basic",
"SliceStaticModule_basic",
"SliceModule_basic",
"SliceNegIdxModule_basic",
@ -342,6 +344,12 @@ STABLEHLO_PASS_SET = {
"SliceStartEqEndModule_basic",
"SliceSizeTwoStepModule_basic",
"SliceWholeTensorModule_basic",
"SliceScatterModule_basic",
"SliceScatterNegativeDimModule_basic",
"SliceScatterNegativeEndModule_basic",
"SliceScatterStaticModule_basic",
"SliceScatterStepVariationModule_basic",
"SliceScatterZeroDimModule_basic",
"SqueezeDimModule_static",
"SqueezeDimModule_identity",
"SqueezeModule_broadcast",

View File

@ -89,6 +89,9 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
std::optional<Type> srcOriginalDtype = std::nullopt);
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
Value torchOptionalInt, Value builtinInt,
Value defaultValue, Value dimSize);
} // namespace Torch
} // namespace torch
} // namespace mlir

View File

@ -33,31 +33,6 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
static Value toPositiveValidDim(ConversionPatternRewriter &rewriter,
Location loc, Value torchOptionalInt,
Value builtinInt, Value defaultValue,
Value dimSize) {
if (torchOptionalInt.getType().isa<Torch::NoneType>())
return defaultValue;
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
Value positiveDim =
toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeAsInt);
// positveDim < 0 ? 0 : positiveDim
Value cst0 = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, positiveDim, cst0);
Value atLeastZero =
rewriter.create<arith::SelectOp>(loc, predDimSltZero, cst0, positiveDim);
// atLeastZero > dimSizeAsInt ? dimSizeAsInt : atLeastZero
Value sgtDimSize = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt);
Value boundedByDimSize = rewriter.create<arith::SelectOp>(
loc, sgtDimSize, dimSizeAsInt, atLeastZero);
return castIntToIndex(rewriter, loc, boundedByDimSize);
}
template <typename OpTy, typename OpAdaptor>
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,

View File

@ -2,7 +2,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo
TorchToStablehlo.cpp
StablehloLegalizeUtils.cpp
Basic.cpp
Gather.cpp
GatherScatter.cpp
Linear.cpp
ViewLike.cpp
Reduction.cpp

View File

@ -96,6 +96,75 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
sliceSizesTensor, dimsAttr)
.getResult();
}
template <typename OpTy, typename OpAdaptor>
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
SmallVector<Value> &resultShape,
SmallVector<Value> &offsets,
SmallVector<Value> &strides) {
Location loc = op.getLoc();
auto input = adaptor.getSelf();
RankedTensorType inputType =
input.getType().template cast<RankedTensorType>();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return op->emitError("unimplemented: dim is not constant");
int64_t inputRank = inputType.getRank();
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
Value dimSize = inputShape[dim];
Value torchTypeStart = op.getStart();
Value torchTypeEnd = op.getEnd();
Value builtinTypeStart = adaptor.getStart();
Value builtinTypeEnd = adaptor.getEnd();
if (torchTypeStart.getType().isa<OptionalType>() ||
torchTypeEnd.getType().isa<OptionalType>())
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
int64_t step;
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
if (!op.getStep().getType().template isa<Torch::NoneType>())
return op->emitError("unimplemented: step is not constant");
step = 1;
}
Value start = toPositiveValidDim(rewriter, loc, torchTypeStart,
builtinTypeStart, zero, dimSize);
Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd,
dimSize, dimSize);
// end >= start ? end : start
Value endSgeStart = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, end, start);
end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step);
// Slice logic: resultSize = floordiv(end - start + step - 1, step)
resultShape = getTensorSizes(rewriter, loc, input);
Value len = rewriter.create<arith::SubIOp>(loc, end, start);
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
resultSize = rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
resultShape[dim] = resultSize;
strides.resize(inputType.getRank(), one);
offsets.resize(inputType.getRank(), zero);
offsets[dim] = start;
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
return success();
}
} // namespace
// Ref:
@ -258,9 +327,54 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
return success();
}
void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) {
// AtenSliceScatterOp
template <>
LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
AtenSliceScatterOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
TypeConverter *typeConverter = getTypeConverter();
auto input = adaptor.getSelf();
RankedTensorType resultType =
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
SmallVector<Value> resultShape;
SmallVector<Value> offsets;
SmallVector<Value> strides;
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
AtenSliceScatterOpAdaptor>(
op, adaptor, rewriter, resultShape, offsets, strides))) {
return failure();
}
Value src = adaptor.getSrc();
auto srcType = src.getType().cast<RankedTensorType>();
int64_t srcRank = srcType.getRank();
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
auto abstractSrcType = RankedTensorType::get(
makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
Value abstractSrc =
rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
Value result = rewriter.create<tensor::InsertSliceOp>(
loc, abstractSrc, input, offsets, resultShape, strides);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
}
void mlir::torch::torch_to_stablehlo::
populateGatherScatterOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
#define INSERT_ATENOP_PATTERN(AtenOp) \
@ -269,5 +383,6 @@ void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
INSERT_ATENOP_PATTERN(AtenGatherOp);
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
#undef INSERT_ATENOP_PATTERN
}

View File

@ -48,7 +48,7 @@ void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
void populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateGatherOpPatternsAndLegality(
void populateGatherScatterOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateReductionOpPatternsAndLegality(

View File

@ -65,7 +65,7 @@ public:
typeConverter, patterns, target, options);
torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populateGatherOpPatternsAndLegality(
torch_to_stablehlo::populateGatherScatterOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populateReductionOpPatternsAndLegality(
typeConverter, patterns, target, options);

View File

@ -324,6 +324,29 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
llvm_unreachable("convertScalarToDtype should handle all the types");
}
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
Value torchOptionalInt, Value builtinInt,
Value defaultValue, Value dimSize) {
if (torchOptionalInt.getType().isa<Torch::NoneType>())
return defaultValue;
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
Value positiveDim =
toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeAsInt);
// positiveDim < 0 ? 0 : positiveDim
Value cst0 = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, positiveDim, cst0);
Value atLeastZero =
rewriter.create<arith::SelectOp>(loc, predDimSltZero, cst0, positiveDim);
// atLeastZero > dimSizeAsInt ? dimSizeAsInt : atLeastZero
Value sgtDimSize = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt);
Value boundedByDimSize = rewriter.create<arith::SelectOp>(
loc, sgtDimSize, dimSizeAsInt, atLeastZero);
return castIntToIndex(rewriter, loc, boundedByDimSize);
}
} // namespace Torch
} // namespace torch
} // namespace mlir

View File

@ -307,6 +307,23 @@ class SliceScatterZeroDimModule(torch.nn.Module):
def SliceScatterZeroDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 8), tu.rand(1, 8))
class SliceScatterNegativeEndModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
])
def forward(self, x, src):
return torch.ops.aten.slice_scatter(x, src, dim = 0, start = 3, end = -1, step = 1)
@register_test_case(module_factory=lambda: SliceScatterNegativeEndModule())
def SliceScatterNegativeEndModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 8), tu.rand(2, 8))
class SliceScatterNegativeDimModule(torch.nn.Module):