mirror of https://github.com/llvm/torch-mlir
[StableHLO] Support for slice_scatter (#1960)
Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>pull/1965/head snapshot-20230323.786
parent
544b5f232b
commit
5758a0bfbb
|
@ -334,6 +334,8 @@ STABLEHLO_PASS_SET = {
|
|||
"RsubIntModule_basic",
|
||||
"RsubIntModule_noalpha_basic",
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
"SelectScattertModule_basic",
|
||||
"SelectScattertStaticModule_basic",
|
||||
"SliceStaticModule_basic",
|
||||
"SliceModule_basic",
|
||||
"SliceNegIdxModule_basic",
|
||||
|
@ -342,6 +344,12 @@ STABLEHLO_PASS_SET = {
|
|||
"SliceStartEqEndModule_basic",
|
||||
"SliceSizeTwoStepModule_basic",
|
||||
"SliceWholeTensorModule_basic",
|
||||
"SliceScatterModule_basic",
|
||||
"SliceScatterNegativeDimModule_basic",
|
||||
"SliceScatterNegativeEndModule_basic",
|
||||
"SliceScatterStaticModule_basic",
|
||||
"SliceScatterStepVariationModule_basic",
|
||||
"SliceScatterZeroDimModule_basic",
|
||||
"SqueezeDimModule_static",
|
||||
"SqueezeDimModule_identity",
|
||||
"SqueezeModule_broadcast",
|
||||
|
|
|
@ -89,6 +89,9 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
|
|||
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||
std::optional<Type> srcOriginalDtype = std::nullopt);
|
||||
|
||||
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value torchOptionalInt, Value builtinInt,
|
||||
Value defaultValue, Value dimSize);
|
||||
} // namespace Torch
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -33,31 +33,6 @@ using namespace mlir;
|
|||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
static Value toPositiveValidDim(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value torchOptionalInt,
|
||||
Value builtinInt, Value defaultValue,
|
||||
Value dimSize) {
|
||||
if (torchOptionalInt.getType().isa<Torch::NoneType>())
|
||||
return defaultValue;
|
||||
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
|
||||
Value positiveDim =
|
||||
toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeAsInt);
|
||||
// positveDim < 0 ? 0 : positiveDim
|
||||
Value cst0 = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
|
||||
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, positiveDim, cst0);
|
||||
Value atLeastZero =
|
||||
rewriter.create<arith::SelectOp>(loc, predDimSltZero, cst0, positiveDim);
|
||||
// atLeastZero > dimSizeAsInt ? dimSizeAsInt : atLeastZero
|
||||
Value sgtDimSize = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt);
|
||||
Value boundedByDimSize = rewriter.create<arith::SelectOp>(
|
||||
loc, sgtDimSize, dimSizeAsInt, atLeastZero);
|
||||
|
||||
return castIntToIndex(rewriter, loc, boundedByDimSize);
|
||||
}
|
||||
|
||||
template <typename OpTy, typename OpAdaptor>
|
||||
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
|
|
|
@ -2,7 +2,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo
|
|||
TorchToStablehlo.cpp
|
||||
StablehloLegalizeUtils.cpp
|
||||
Basic.cpp
|
||||
Gather.cpp
|
||||
GatherScatter.cpp
|
||||
Linear.cpp
|
||||
ViewLike.cpp
|
||||
Reduction.cpp
|
||||
|
|
|
@ -96,6 +96,75 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
|||
sliceSizesTensor, dimsAttr)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
template <typename OpTy, typename OpAdaptor>
|
||||
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
SmallVector<Value> &resultShape,
|
||||
SmallVector<Value> &offsets,
|
||||
SmallVector<Value> &strides) {
|
||||
Location loc = op.getLoc();
|
||||
auto input = adaptor.getSelf();
|
||||
RankedTensorType inputType =
|
||||
input.getType().template cast<RankedTensorType>();
|
||||
|
||||
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return op->emitError("unimplemented: dim is not constant");
|
||||
|
||||
int64_t inputRank = inputType.getRank();
|
||||
dim = toPositiveDim(dim, inputRank);
|
||||
if (!isValidDim(dim, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
|
||||
Value dimSize = inputShape[dim];
|
||||
|
||||
Value torchTypeStart = op.getStart();
|
||||
Value torchTypeEnd = op.getEnd();
|
||||
Value builtinTypeStart = adaptor.getStart();
|
||||
Value builtinTypeEnd = adaptor.getEnd();
|
||||
|
||||
if (torchTypeStart.getType().isa<OptionalType>() ||
|
||||
torchTypeEnd.getType().isa<OptionalType>())
|
||||
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
|
||||
|
||||
int64_t step;
|
||||
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
|
||||
if (!op.getStep().getType().template isa<Torch::NoneType>())
|
||||
return op->emitError("unimplemented: step is not constant");
|
||||
step = 1;
|
||||
}
|
||||
|
||||
Value start = toPositiveValidDim(rewriter, loc, torchTypeStart,
|
||||
builtinTypeStart, zero, dimSize);
|
||||
Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd,
|
||||
dimSize, dimSize);
|
||||
|
||||
// end >= start ? end : start
|
||||
Value endSgeStart = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, end, start);
|
||||
end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
|
||||
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step);
|
||||
|
||||
// Slice logic: resultSize = floordiv(end - start + step - 1, step)
|
||||
resultShape = getTensorSizes(rewriter, loc, input);
|
||||
Value len = rewriter.create<arith::SubIOp>(loc, end, start);
|
||||
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
|
||||
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
|
||||
resultSize = rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
|
||||
resultShape[dim] = resultSize;
|
||||
|
||||
strides.resize(inputType.getRank(), one);
|
||||
offsets.resize(inputType.getRank(), zero);
|
||||
|
||||
offsets[dim] = start;
|
||||
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Ref:
|
||||
|
@ -258,7 +327,52 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality(
|
||||
// AtenSliceScatterOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
|
||||
AtenSliceScatterOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
TypeConverter *typeConverter = getTypeConverter();
|
||||
|
||||
auto input = adaptor.getSelf();
|
||||
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
|
||||
SmallVector<Value> resultShape;
|
||||
SmallVector<Value> offsets;
|
||||
SmallVector<Value> strides;
|
||||
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
|
||||
AtenSliceScatterOpAdaptor>(
|
||||
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Value src = adaptor.getSrc();
|
||||
auto srcType = src.getType().cast<RankedTensorType>();
|
||||
int64_t srcRank = srcType.getRank();
|
||||
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
||||
auto abstractSrcType = RankedTensorType::get(
|
||||
makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
|
||||
Value abstractSrc =
|
||||
rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
|
||||
|
||||
Value result = rewriter.create<tensor::InsertSliceOp>(
|
||||
loc, abstractSrc, input, offsets, resultShape, strides);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_stablehlo::
|
||||
populateGatherScatterOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
@ -269,5 +383,6 @@ void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGatherOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
}
|
|
@ -48,7 +48,7 @@ void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
|
|||
void populateViewLikeOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options);
|
||||
void populateGatherOpPatternsAndLegality(
|
||||
void populateGatherScatterOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options);
|
||||
void populateReductionOpPatternsAndLegality(
|
||||
|
|
|
@ -65,7 +65,7 @@ public:
|
|||
typeConverter, patterns, target, options);
|
||||
torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_stablehlo::populateGatherOpPatternsAndLegality(
|
||||
torch_to_stablehlo::populateGatherScatterOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_stablehlo::populateReductionOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
|
|
|
@ -324,6 +324,29 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
|||
llvm_unreachable("convertScalarToDtype should handle all the types");
|
||||
}
|
||||
|
||||
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value torchOptionalInt, Value builtinInt,
|
||||
Value defaultValue, Value dimSize) {
|
||||
if (torchOptionalInt.getType().isa<Torch::NoneType>())
|
||||
return defaultValue;
|
||||
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
|
||||
Value positiveDim =
|
||||
toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeAsInt);
|
||||
// positiveDim < 0 ? 0 : positiveDim
|
||||
Value cst0 = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
|
||||
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, positiveDim, cst0);
|
||||
Value atLeastZero =
|
||||
rewriter.create<arith::SelectOp>(loc, predDimSltZero, cst0, positiveDim);
|
||||
// atLeastZero > dimSizeAsInt ? dimSizeAsInt : atLeastZero
|
||||
Value sgtDimSize = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt);
|
||||
Value boundedByDimSize = rewriter.create<arith::SelectOp>(
|
||||
loc, sgtDimSize, dimSizeAsInt, atLeastZero);
|
||||
|
||||
return castIntToIndex(rewriter, loc, boundedByDimSize);
|
||||
}
|
||||
} // namespace Torch
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -307,6 +307,23 @@ class SliceScatterZeroDimModule(torch.nn.Module):
|
|||
def SliceScatterZeroDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 8), tu.rand(1, 8))
|
||||
|
||||
class SliceScatterNegativeEndModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, src):
|
||||
return torch.ops.aten.slice_scatter(x, src, dim = 0, start = 3, end = -1, step = 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceScatterNegativeEndModule())
|
||||
def SliceScatterNegativeEndModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 8), tu.rand(2, 8))
|
||||
|
||||
class SliceScatterNegativeDimModule(torch.nn.Module):
|
||||
|
||||
|
|
Loading…
Reference in New Issue