From 5758a0bfbbbb34aa7dd2f9e9cc029c4c962e05ed Mon Sep 17 00:00:00 2001 From: Zhekun Zhang <32320144+zhekunz2@users.noreply.github.com> Date: Wed, 22 Mar 2023 13:41:04 -0700 Subject: [PATCH] [StableHLO] Support for slice_scatter (#1960) Co-authored-by: zhekun.zhang --- e2e_testing/xfail_sets.py | 8 ++ include/torch-mlir/Conversion/Utils/Utils.h | 3 + lib/Conversion/TorchToLinalg/DataMovement.cpp | 25 ---- .../TorchToStablehlo/CMakeLists.txt | 2 +- .../{Gather.cpp => GatherScatter.cpp} | 121 +++++++++++++++++- .../TorchToStablehlo/PopulatePatterns.h | 2 +- .../TorchToStablehlo/TorchToStablehlo.cpp | 2 +- lib/Conversion/Utils/Utils.cpp | 23 ++++ .../test_suite/slice_like.py | 17 +++ 9 files changed, 172 insertions(+), 31 deletions(-) rename lib/Conversion/TorchToStablehlo/{Gather.cpp => GatherScatter.cpp} (69%) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 01d323518..2022ae04c 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -334,6 +334,8 @@ STABLEHLO_PASS_SET = { "RsubIntModule_basic", "RsubIntModule_noalpha_basic", "RsubInt0d_NumToTensor_Module_basic", + "SelectScattertModule_basic", + "SelectScattertStaticModule_basic", "SliceStaticModule_basic", "SliceModule_basic", "SliceNegIdxModule_basic", @@ -342,6 +344,12 @@ STABLEHLO_PASS_SET = { "SliceStartEqEndModule_basic", "SliceSizeTwoStepModule_basic", "SliceWholeTensorModule_basic", + "SliceScatterModule_basic", + "SliceScatterNegativeDimModule_basic", + "SliceScatterNegativeEndModule_basic", + "SliceScatterStaticModule_basic", + "SliceScatterStepVariationModule_basic", + "SliceScatterZeroDimModule_basic", "SqueezeDimModule_static", "SqueezeDimModule_identity", "SqueezeModule_broadcast", diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index adafe173c..15484e5f5 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -89,6 +89,9 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef shape, Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, std::optional srcOriginalDtype = std::nullopt); +Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, + Value torchOptionalInt, Value builtinInt, + Value defaultValue, Value dimSize); } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 293649de5..d261835b5 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -33,31 +33,6 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -static Value toPositiveValidDim(ConversionPatternRewriter &rewriter, - Location loc, Value torchOptionalInt, - Value builtinInt, Value defaultValue, - Value dimSize) { - if (torchOptionalInt.getType().isa()) - return defaultValue; - auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize); - Value positiveDim = - toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeAsInt); - // positveDim < 0 ? 0 : positiveDim - Value cst0 = rewriter.create( - loc, rewriter.getZeroAttr(dimSizeAsInt.getType())); - Value predDimSltZero = rewriter.create( - loc, arith::CmpIPredicate::slt, positiveDim, cst0); - Value atLeastZero = - rewriter.create(loc, predDimSltZero, cst0, positiveDim); - // atLeastZero > dimSizeAsInt ? dimSizeAsInt : atLeastZero - Value sgtDimSize = rewriter.create( - loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt); - Value boundedByDimSize = rewriter.create( - loc, sgtDimSize, dimSizeAsInt, atLeastZero); - - return castIntToIndex(rewriter, loc, boundedByDimSize); -} - template LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, diff --git a/lib/Conversion/TorchToStablehlo/CMakeLists.txt b/lib/Conversion/TorchToStablehlo/CMakeLists.txt index 237512980..84a560cd7 100644 --- a/lib/Conversion/TorchToStablehlo/CMakeLists.txt +++ b/lib/Conversion/TorchToStablehlo/CMakeLists.txt @@ -2,7 +2,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo TorchToStablehlo.cpp StablehloLegalizeUtils.cpp Basic.cpp - Gather.cpp + GatherScatter.cpp Linear.cpp ViewLike.cpp Reduction.cpp diff --git a/lib/Conversion/TorchToStablehlo/Gather.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp similarity index 69% rename from lib/Conversion/TorchToStablehlo/Gather.cpp rename to lib/Conversion/TorchToStablehlo/GatherScatter.cpp index 437332703..2230e714b 100644 --- a/lib/Conversion/TorchToStablehlo/Gather.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -96,6 +96,75 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, sliceSizesTensor, dimsAttr) .getResult(); } + +template +LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + SmallVector &resultShape, + SmallVector &offsets, + SmallVector &strides) { + Location loc = op.getLoc(); + auto input = adaptor.getSelf(); + RankedTensorType inputType = + input.getType().template cast(); + + Value zero = rewriter.create(loc, 0); + Value one = rewriter.create(loc, 1); + + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return op->emitError("unimplemented: dim is not constant"); + + int64_t inputRank = inputType.getRank(); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + + SmallVector inputShape = getTensorSizes(rewriter, loc, input); + Value dimSize = inputShape[dim]; + + Value torchTypeStart = op.getStart(); + Value torchTypeEnd = op.getEnd(); + Value builtinTypeStart = adaptor.getStart(); + Value builtinTypeEnd = adaptor.getEnd(); + + if (torchTypeStart.getType().isa() || + torchTypeEnd.getType().isa()) + return rewriter.notifyMatchFailure(op, "unimplemented optional type arg"); + + int64_t step; + if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) { + if (!op.getStep().getType().template isa()) + return op->emitError("unimplemented: step is not constant"); + step = 1; + } + + Value start = toPositiveValidDim(rewriter, loc, torchTypeStart, + builtinTypeStart, zero, dimSize); + Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd, + dimSize, dimSize); + + // end >= start ? end : start + Value endSgeStart = rewriter.create( + loc, arith::CmpIPredicate::sge, end, start); + end = rewriter.create(loc, endSgeStart, end, start); + Value stepIndex = rewriter.create(loc, step); + + // Slice logic: resultSize = floordiv(end - start + step - 1, step) + resultShape = getTensorSizes(rewriter, loc, input); + Value len = rewriter.create(loc, end, start); + Value resultSize = rewriter.create(loc, len, stepIndex); + resultSize = rewriter.create(loc, resultSize, one); + resultSize = rewriter.create(loc, resultSize, stepIndex); + resultShape[dim] = resultSize; + + strides.resize(inputType.getRank(), one); + offsets.resize(inputType.getRank(), zero); + + offsets[dim] = start; + strides[dim] = rewriter.create(loc, strides[dim], stepIndex); + return success(); +} } // namespace // Ref: @@ -258,9 +327,54 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality( - TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToStablehloOptions &options) { +// AtenSliceScatterOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenSliceScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + Location loc = op.getLoc(); + TypeConverter *typeConverter = getTypeConverter(); + + auto input = adaptor.getSelf(); + + RankedTensorType resultType = + typeConverter->convertType(op->getResult(0).getType()) + .cast(); + + SmallVector resultShape; + SmallVector offsets; + SmallVector strides; + if (failed(prepareArgumentsForSlicingOp( + op, adaptor, rewriter, resultShape, offsets, strides))) { + return failure(); + } + + Value src = adaptor.getSrc(); + auto srcType = src.getType().cast(); + int64_t srcRank = srcType.getRank(); + SmallVector srcAbstractSizes(srcRank, kUnknownSize); + auto abstractSrcType = RankedTensorType::get( + makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType()); + Value abstractSrc = + rewriter.create(loc, abstractSrcType, src); + + Value result = rewriter.create( + loc, abstractSrc, input, offsets, resultShape, strides); + + rewriter.replaceOpWithNewOp(op, resultType, result); + + return success(); +} + +void mlir::torch::torch_to_stablehlo:: + populateGatherScatterOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATENOP_PATTERN(AtenOp) \ @@ -269,5 +383,6 @@ void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenEmbeddingOp); INSERT_ATENOP_PATTERN(AtenIndexSelectOp); INSERT_ATENOP_PATTERN(AtenGatherOp); + INSERT_ATENOP_PATTERN(AtenSliceScatterOp); #undef INSERT_ATENOP_PATTERN } diff --git a/lib/Conversion/TorchToStablehlo/PopulatePatterns.h b/lib/Conversion/TorchToStablehlo/PopulatePatterns.h index b6322efd6..fc28acfde 100644 --- a/lib/Conversion/TorchToStablehlo/PopulatePatterns.h +++ b/lib/Conversion/TorchToStablehlo/PopulatePatterns.h @@ -48,7 +48,7 @@ void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter, void populateViewLikeOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options); -void populateGatherOpPatternsAndLegality( +void populateGatherScatterOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options); void populateReductionOpPatternsAndLegality( diff --git a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp index ba0838484..434d55c76 100644 --- a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp +++ b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp @@ -65,7 +65,7 @@ public: typeConverter, patterns, target, options); torch_to_stablehlo::populateViewLikeOpPatternsAndLegality( typeConverter, patterns, target, options); - torch_to_stablehlo::populateGatherOpPatternsAndLegality( + torch_to_stablehlo::populateGatherScatterOpPatternsAndLegality( typeConverter, patterns, target, options); torch_to_stablehlo::populateReductionOpPatternsAndLegality( typeConverter, patterns, target, options); diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 906cc3c44..474032f77 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -324,6 +324,29 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, llvm_unreachable("convertScalarToDtype should handle all the types"); } +Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, + Value torchOptionalInt, Value builtinInt, + Value defaultValue, Value dimSize) { + if (torchOptionalInt.getType().isa()) + return defaultValue; + auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize); + Value positiveDim = + toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeAsInt); + // positiveDim < 0 ? 0 : positiveDim + Value cst0 = rewriter.create( + loc, rewriter.getZeroAttr(dimSizeAsInt.getType())); + Value predDimSltZero = rewriter.create( + loc, arith::CmpIPredicate::slt, positiveDim, cst0); + Value atLeastZero = + rewriter.create(loc, predDimSltZero, cst0, positiveDim); + // atLeastZero > dimSizeAsInt ? dimSizeAsInt : atLeastZero + Value sgtDimSize = rewriter.create( + loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt); + Value boundedByDimSize = rewriter.create( + loc, sgtDimSize, dimSizeAsInt, atLeastZero); + + return castIntToIndex(rewriter, loc, boundedByDimSize); +} } // namespace Torch } // namespace torch } // namespace mlir diff --git a/python/torch_mlir_e2e_test/test_suite/slice_like.py b/python/torch_mlir_e2e_test/test_suite/slice_like.py index 1e8566826..08cb00e19 100644 --- a/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -307,6 +307,23 @@ class SliceScatterZeroDimModule(torch.nn.Module): def SliceScatterZeroDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(6, 8), tu.rand(1, 8)) +class SliceScatterNegativeEndModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, x, src): + return torch.ops.aten.slice_scatter(x, src, dim = 0, start = 3, end = -1, step = 1) + + +@register_test_case(module_factory=lambda: SliceScatterNegativeEndModule()) +def SliceScatterNegativeEndModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 8), tu.rand(2, 8)) class SliceScatterNegativeDimModule(torch.nn.Module):