mirror of https://github.com/llvm/torch-mlir
[StableHLO] Support for slice_scatter (#1960)
Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>pull/1965/head snapshot-20230323.786
parent
544b5f232b
commit
5758a0bfbb
|
@ -334,6 +334,8 @@ STABLEHLO_PASS_SET = {
|
||||||
"RsubIntModule_basic",
|
"RsubIntModule_basic",
|
||||||
"RsubIntModule_noalpha_basic",
|
"RsubIntModule_noalpha_basic",
|
||||||
"RsubInt0d_NumToTensor_Module_basic",
|
"RsubInt0d_NumToTensor_Module_basic",
|
||||||
|
"SelectScattertModule_basic",
|
||||||
|
"SelectScattertStaticModule_basic",
|
||||||
"SliceStaticModule_basic",
|
"SliceStaticModule_basic",
|
||||||
"SliceModule_basic",
|
"SliceModule_basic",
|
||||||
"SliceNegIdxModule_basic",
|
"SliceNegIdxModule_basic",
|
||||||
|
@ -342,6 +344,12 @@ STABLEHLO_PASS_SET = {
|
||||||
"SliceStartEqEndModule_basic",
|
"SliceStartEqEndModule_basic",
|
||||||
"SliceSizeTwoStepModule_basic",
|
"SliceSizeTwoStepModule_basic",
|
||||||
"SliceWholeTensorModule_basic",
|
"SliceWholeTensorModule_basic",
|
||||||
|
"SliceScatterModule_basic",
|
||||||
|
"SliceScatterNegativeDimModule_basic",
|
||||||
|
"SliceScatterNegativeEndModule_basic",
|
||||||
|
"SliceScatterStaticModule_basic",
|
||||||
|
"SliceScatterStepVariationModule_basic",
|
||||||
|
"SliceScatterZeroDimModule_basic",
|
||||||
"SqueezeDimModule_static",
|
"SqueezeDimModule_static",
|
||||||
"SqueezeDimModule_identity",
|
"SqueezeDimModule_identity",
|
||||||
"SqueezeModule_broadcast",
|
"SqueezeModule_broadcast",
|
||||||
|
|
|
@ -89,6 +89,9 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
|
||||||
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||||
std::optional<Type> srcOriginalDtype = std::nullopt);
|
std::optional<Type> srcOriginalDtype = std::nullopt);
|
||||||
|
|
||||||
|
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
|
Value torchOptionalInt, Value builtinInt,
|
||||||
|
Value defaultValue, Value dimSize);
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -33,31 +33,6 @@ using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
static Value toPositiveValidDim(ConversionPatternRewriter &rewriter,
|
|
||||||
Location loc, Value torchOptionalInt,
|
|
||||||
Value builtinInt, Value defaultValue,
|
|
||||||
Value dimSize) {
|
|
||||||
if (torchOptionalInt.getType().isa<Torch::NoneType>())
|
|
||||||
return defaultValue;
|
|
||||||
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
|
|
||||||
Value positiveDim =
|
|
||||||
toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeAsInt);
|
|
||||||
// positveDim < 0 ? 0 : positiveDim
|
|
||||||
Value cst0 = rewriter.create<arith::ConstantOp>(
|
|
||||||
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
|
|
||||||
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
|
|
||||||
loc, arith::CmpIPredicate::slt, positiveDim, cst0);
|
|
||||||
Value atLeastZero =
|
|
||||||
rewriter.create<arith::SelectOp>(loc, predDimSltZero, cst0, positiveDim);
|
|
||||||
// atLeastZero > dimSizeAsInt ? dimSizeAsInt : atLeastZero
|
|
||||||
Value sgtDimSize = rewriter.create<arith::CmpIOp>(
|
|
||||||
loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt);
|
|
||||||
Value boundedByDimSize = rewriter.create<arith::SelectOp>(
|
|
||||||
loc, sgtDimSize, dimSizeAsInt, atLeastZero);
|
|
||||||
|
|
||||||
return castIntToIndex(rewriter, loc, boundedByDimSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename OpTy, typename OpAdaptor>
|
template <typename OpTy, typename OpAdaptor>
|
||||||
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
|
|
|
@ -2,7 +2,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo
|
||||||
TorchToStablehlo.cpp
|
TorchToStablehlo.cpp
|
||||||
StablehloLegalizeUtils.cpp
|
StablehloLegalizeUtils.cpp
|
||||||
Basic.cpp
|
Basic.cpp
|
||||||
Gather.cpp
|
GatherScatter.cpp
|
||||||
Linear.cpp
|
Linear.cpp
|
||||||
ViewLike.cpp
|
ViewLike.cpp
|
||||||
Reduction.cpp
|
Reduction.cpp
|
||||||
|
|
|
@ -96,6 +96,75 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
||||||
sliceSizesTensor, dimsAttr)
|
sliceSizesTensor, dimsAttr)
|
||||||
.getResult();
|
.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename OpTy, typename OpAdaptor>
|
||||||
|
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter,
|
||||||
|
SmallVector<Value> &resultShape,
|
||||||
|
SmallVector<Value> &offsets,
|
||||||
|
SmallVector<Value> &strides) {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
auto input = adaptor.getSelf();
|
||||||
|
RankedTensorType inputType =
|
||||||
|
input.getType().template cast<RankedTensorType>();
|
||||||
|
|
||||||
|
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||||
|
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||||
|
|
||||||
|
int64_t dim;
|
||||||
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
|
return op->emitError("unimplemented: dim is not constant");
|
||||||
|
|
||||||
|
int64_t inputRank = inputType.getRank();
|
||||||
|
dim = toPositiveDim(dim, inputRank);
|
||||||
|
if (!isValidDim(dim, inputRank))
|
||||||
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||||
|
|
||||||
|
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
|
||||||
|
Value dimSize = inputShape[dim];
|
||||||
|
|
||||||
|
Value torchTypeStart = op.getStart();
|
||||||
|
Value torchTypeEnd = op.getEnd();
|
||||||
|
Value builtinTypeStart = adaptor.getStart();
|
||||||
|
Value builtinTypeEnd = adaptor.getEnd();
|
||||||
|
|
||||||
|
if (torchTypeStart.getType().isa<OptionalType>() ||
|
||||||
|
torchTypeEnd.getType().isa<OptionalType>())
|
||||||
|
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
|
||||||
|
|
||||||
|
int64_t step;
|
||||||
|
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
|
||||||
|
if (!op.getStep().getType().template isa<Torch::NoneType>())
|
||||||
|
return op->emitError("unimplemented: step is not constant");
|
||||||
|
step = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value start = toPositiveValidDim(rewriter, loc, torchTypeStart,
|
||||||
|
builtinTypeStart, zero, dimSize);
|
||||||
|
Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd,
|
||||||
|
dimSize, dimSize);
|
||||||
|
|
||||||
|
// end >= start ? end : start
|
||||||
|
Value endSgeStart = rewriter.create<arith::CmpIOp>(
|
||||||
|
loc, arith::CmpIPredicate::sge, end, start);
|
||||||
|
end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
|
||||||
|
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step);
|
||||||
|
|
||||||
|
// Slice logic: resultSize = floordiv(end - start + step - 1, step)
|
||||||
|
resultShape = getTensorSizes(rewriter, loc, input);
|
||||||
|
Value len = rewriter.create<arith::SubIOp>(loc, end, start);
|
||||||
|
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
|
||||||
|
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
|
||||||
|
resultSize = rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
|
||||||
|
resultShape[dim] = resultSize;
|
||||||
|
|
||||||
|
strides.resize(inputType.getRank(), one);
|
||||||
|
offsets.resize(inputType.getRank(), zero);
|
||||||
|
|
||||||
|
offsets[dim] = start;
|
||||||
|
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Ref:
|
// Ref:
|
||||||
|
@ -258,7 +327,52 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality(
|
// AtenSliceScatterOp
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
|
||||||
|
AtenSliceScatterOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
TypeConverter *typeConverter = getTypeConverter();
|
||||||
|
|
||||||
|
auto input = adaptor.getSelf();
|
||||||
|
|
||||||
|
RankedTensorType resultType =
|
||||||
|
typeConverter->convertType(op->getResult(0).getType())
|
||||||
|
.cast<RankedTensorType>();
|
||||||
|
|
||||||
|
SmallVector<Value> resultShape;
|
||||||
|
SmallVector<Value> offsets;
|
||||||
|
SmallVector<Value> strides;
|
||||||
|
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
|
||||||
|
AtenSliceScatterOpAdaptor>(
|
||||||
|
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value src = adaptor.getSrc();
|
||||||
|
auto srcType = src.getType().cast<RankedTensorType>();
|
||||||
|
int64_t srcRank = srcType.getRank();
|
||||||
|
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
||||||
|
auto abstractSrcType = RankedTensorType::get(
|
||||||
|
makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
|
||||||
|
Value abstractSrc =
|
||||||
|
rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
|
||||||
|
|
||||||
|
Value result = rewriter.create<tensor::InsertSliceOp>(
|
||||||
|
loc, abstractSrc, input, offsets, resultShape, strides);
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void mlir::torch::torch_to_stablehlo::
|
||||||
|
populateGatherScatterOpPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
|
@ -269,5 +383,6 @@ void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality(
|
||||||
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
|
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenGatherOp);
|
INSERT_ATENOP_PATTERN(AtenGatherOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
}
|
}
|
|
@ -48,7 +48,7 @@ void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||||
void populateViewLikeOpPatternsAndLegality(
|
void populateViewLikeOpPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target, const TorchToStablehloOptions &options);
|
ConversionTarget &target, const TorchToStablehloOptions &options);
|
||||||
void populateGatherOpPatternsAndLegality(
|
void populateGatherScatterOpPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target, const TorchToStablehloOptions &options);
|
ConversionTarget &target, const TorchToStablehloOptions &options);
|
||||||
void populateReductionOpPatternsAndLegality(
|
void populateReductionOpPatternsAndLegality(
|
||||||
|
|
|
@ -65,7 +65,7 @@ public:
|
||||||
typeConverter, patterns, target, options);
|
typeConverter, patterns, target, options);
|
||||||
torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
||||||
typeConverter, patterns, target, options);
|
typeConverter, patterns, target, options);
|
||||||
torch_to_stablehlo::populateGatherOpPatternsAndLegality(
|
torch_to_stablehlo::populateGatherScatterOpPatternsAndLegality(
|
||||||
typeConverter, patterns, target, options);
|
typeConverter, patterns, target, options);
|
||||||
torch_to_stablehlo::populateReductionOpPatternsAndLegality(
|
torch_to_stablehlo::populateReductionOpPatternsAndLegality(
|
||||||
typeConverter, patterns, target, options);
|
typeConverter, patterns, target, options);
|
||||||
|
|
|
@ -324,6 +324,29 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||||
llvm_unreachable("convertScalarToDtype should handle all the types");
|
llvm_unreachable("convertScalarToDtype should handle all the types");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
|
Value torchOptionalInt, Value builtinInt,
|
||||||
|
Value defaultValue, Value dimSize) {
|
||||||
|
if (torchOptionalInt.getType().isa<Torch::NoneType>())
|
||||||
|
return defaultValue;
|
||||||
|
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
|
||||||
|
Value positiveDim =
|
||||||
|
toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeAsInt);
|
||||||
|
// positiveDim < 0 ? 0 : positiveDim
|
||||||
|
Value cst0 = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
|
||||||
|
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
|
||||||
|
loc, arith::CmpIPredicate::slt, positiveDim, cst0);
|
||||||
|
Value atLeastZero =
|
||||||
|
rewriter.create<arith::SelectOp>(loc, predDimSltZero, cst0, positiveDim);
|
||||||
|
// atLeastZero > dimSizeAsInt ? dimSizeAsInt : atLeastZero
|
||||||
|
Value sgtDimSize = rewriter.create<arith::CmpIOp>(
|
||||||
|
loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt);
|
||||||
|
Value boundedByDimSize = rewriter.create<arith::SelectOp>(
|
||||||
|
loc, sgtDimSize, dimSizeAsInt, atLeastZero);
|
||||||
|
|
||||||
|
return castIntToIndex(rewriter, loc, boundedByDimSize);
|
||||||
|
}
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -307,6 +307,23 @@ class SliceScatterZeroDimModule(torch.nn.Module):
|
||||||
def SliceScatterZeroDimModule_basic(module, tu: TestUtils):
|
def SliceScatterZeroDimModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(6, 8), tu.rand(1, 8))
|
module.forward(tu.rand(6, 8), tu.rand(1, 8))
|
||||||
|
|
||||||
|
class SliceScatterNegativeEndModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, src):
|
||||||
|
return torch.ops.aten.slice_scatter(x, src, dim = 0, start = 3, end = -1, step = 1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: SliceScatterNegativeEndModule())
|
||||||
|
def SliceScatterNegativeEndModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(6, 8), tu.rand(2, 8))
|
||||||
|
|
||||||
class SliceScatterNegativeDimModule(torch.nn.Module):
|
class SliceScatterNegativeDimModule(torch.nn.Module):
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue