diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index af7b91d17..de68e6146 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2963,6 +2963,128 @@ public:
 };
 } // namespace

+namespace {
+// def slice_scatter(self, values, dim, start, end, step):
+//     size = self.size(dim)
+//     indices = torch.arange(size)
+//     shift_indices = indices - start
+//     mask = shift_indices % step == 0
+//     start_mask = shift_indices >= 0
+//     end_mask = shift_indices < end
+//     mask = mask * start_mask
+//     mask = mask * end_mask
+//     sizes = list(self.size())
+//     rank = len(sizes)
+//     shape = [1] * rank
+//     shape[dim] = size
+//     mask = mask.view(shape)
+//     return torch.where(mask, values, self)
+//
+class DecomposeAtenSliceScatterOp
+    : public OpRewritePattern<AtenSliceScatterOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSliceScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    int64_t inputRank = getTensorRank(op.self());
+    int64_t dimInt = 0;
+    if (matchPattern(op.dim(), m_TorchConstantInt(&dimInt))) {
+      dimInt = toPositiveDim(dimInt, inputRank);
+      if (!isValidDim(dimInt, inputRank))
+        return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
+    } else {
+      return rewriter.notifyMatchFailure(op, "dim must be constant");
+    }
+
+    auto getOptionalVal = [&](Value val, Value defVal) -> Value {
+      if (val.getType().isa<Torch::NoneType>()) {
+        return defVal;
+      } else {
+        return val;
+      }
+    };
+
+    Value one = rewriter.create<Torch::ConstantIntOp>(
+        op.getLoc(), rewriter.getI64IntegerAttr(1));
+    Value zero = rewriter.create<Torch::ConstantIntOp>(
+        op.getLoc(), rewriter.getI64IntegerAttr(0));
+    Value none = rewriter.create<ConstantNoneOp>(op.getLoc());
+    Value dimSize =
+        rewriter.create<AtenSizeIntOp>(op.getLoc(), op.self(), op.dim());
+
+    Value start = getOptionalVal(op.start(), zero);
+    Value end = getOptionalVal(op.end(), dimSize);
+    Value step = getOptionalVal(op.step(), one);
+    // Step 0. create indices
+    Type indicesType = ValueTensorType::get(
+        op.getContext(), ArrayRef<int64_t>{ShapedType::kDynamicSize},
+        IntegerType::get(op.getContext(), 64, IntegerType::Signed));
+    Value indices = rewriter.create<AtenArangeOp>(
+        op.getLoc(), indicesType, dimSize, none, none, none, none);
+
+    // Step 1. make indices broadcastable to self's shape
+    SmallVector<int64_t> newIndicesShapeInt(inputRank, 1);
+    SmallVector<Value> newIndicesShape(inputRank, one);
+    newIndicesShape[dimInt] = dimSize;
+    newIndicesShapeInt[dimInt] = ShapedType::kDynamicSize;
+    Value newIndicesSizeList = rewriter.create<PrimListConstructOp>(
+        op.getLoc(), ListType::get(IntType::get(op.getContext())),
+        newIndicesShape);
+    Type indicesDtype = indices.getType().cast<BaseTensorType>().getDtype();
+    Type newIndicesType = ValueTensorType::get(
+        op.getContext(), llvm::makeArrayRef(newIndicesShapeInt), indicesDtype);
+    indices = rewriter.create<AtenViewOp>(op.getLoc(), newIndicesType,
+                                          indices, newIndicesSizeList);
+
+    // Step 2. calculate scatter indices mask
+    Type maskType = ValueTensorType::get(
+        op.getContext(), newIndicesType.cast<BaseTensorType>().getSizes(),
+        IntegerType::get(op.getContext(), 1));
+    auto shiftIndices = rewriter.create<AtenSubScalarOp>(
+        op.getLoc(), indices.getType(), indices, start, one);
+    auto stepRemainder = rewriter.create<AtenRemainderScalarOp>(
+        op.getLoc(), indices.getType(), shiftIndices, step);
+    Value mask = rewriter.create<AtenEqScalarOp>(op.getLoc(), maskType,
+                                                 stepRemainder, zero);
+    auto maskStart = rewriter.create<AtenGeScalarOp>(op.getLoc(), maskType,
+                                                     shiftIndices, zero);
+    auto maskEnd =
+        rewriter.create<AtenLtScalarOp>(op.getLoc(), maskType, indices, end);
+    mask = rewriter.create<AtenBitwiseAndTensorOp>(op.getLoc(), maskType, mask,
+                                                   maskStart);
+    mask = rewriter.create<AtenBitwiseAndTensorOp>(op.getLoc(), maskType, mask,
+                                                   maskEnd);
+
+    // Step 3. make src broadcastable to self's shape
+    Value src = op.src();
+    BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
+    if (!srcTensorType.hasSizes())
+      return rewriter.notifyMatchFailure(op, "src tensor must have size");
+
+    ArrayRef<int64_t> srcShape = srcTensorType.getSizes();
+    int64_t srcRank = srcShape.size();
+    if (srcRank != inputRank) {
+      if (srcRank + 1 == inputRank) {
+        SmallVector<int64_t> sizes;
+        sizes.append(srcShape.begin(), srcShape.end());
+        sizes.insert(sizes.begin() + dimInt, 1);
+        Type srcType = srcTensorType.getWithSizesAndDtype(
+            llvm::makeArrayRef(sizes), srcTensorType.getDtype());
+        src = rewriter.create<AtenUnsqueezeOp>(op.getLoc(), srcType, src,
+                                               op.dim());
+      } else {
+        return rewriter.notifyMatchFailure(op, "src's rank doesn't match");
+      }
+    }
+
+    // Step 4. replace output = mask? src: self
+    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, op.getType(), mask,
+                                                 src, op.self());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAten_EmbeddingBagOp
     : public OpRewritePattern<Aten_EmbeddingBagOp> {
@@ -3354,6 +3476,8 @@ public:
     target.addIllegalOp();
     patterns.add(context);
     target.addIllegalOp();
+    patterns.add<DecomposeAtenSliceScatterOp>(context);
+    target.addIllegalOp<AtenSliceScatterOp>();
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py
index 2cc6e984c..9b135ad24 100644
--- a/python/torch_mlir/__init__.py
+++ b/python/torch_mlir/__init__.py
@@ -240,7 +240,7 @@ class ExampleArgs:
 # compiler where each backend can "own" its set of legal ops.
 BACKEND_LEGAL_OPS = {
     OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
-    OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
+    OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', 'torch.aten.slice_scatter'],
     OutputType.MHLO: [],
 }
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 9002894c1..123561b95 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -784,7 +784,7 @@ func.func @torch.aten.numpy_T$rank_three(%arg0: !torch.vtensor<[5,4,3],f32>) ->
 }

 // -----
-// CHECK-LABEL: func.func @torch.aten.repeat(
+// CHECK-LABEL: func @torch.aten.repeat(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int, %[[ARG2:.*]]: !torch.int, %[[ARG3:.*]]: !torch.int) -> !torch.vtensor<[?,?,?],f32> {
 // CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]], %[[ARG3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
@@ -810,14 +810,29 @@ func.func @torch.aten.repeat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int
 // -----
 // CHECK-LABEL: func @torch.aten.select_scatter
 // CHECK-SAME: (%[[SELF:.*]]: !torch.vtensor<[?,?],f32>, %[[SRC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK-NEXT: %[[START:.*]] = torch.constant.int 0
-// CHECK-NEXT: %[[DIM:.*]] = torch.constant.int 1
-// CHECK-NEXT: %[[STEP:.*]] = torch.constant.int 1
-// CHECK-NEXT: %[[END:.*]] = torch.aten.add.int %[[START]], %[[STEP]]
-// CHECK-NEXT: %[[UNSQUEEZE_SRC:.*]] = torch.aten.unsqueeze %[[SRC]], %[[DIM]]
-// CHECK-NEXT: %[[SLICE_SCATTER:.*]] = torch.aten.slice_scatter %[[SELF]], %[[UNSQUEEZE_SRC]], %[[DIM]], %[[START]], %[[END]], %[[STEP]]
-// CHECK-NEXT: return %[[SLICE_SCATTER]]
-// CHECK-NEXT: }
+// CHECK-NEXT: %[[INT0:.*]] = torch.constant.int 0
+// CHECK-NEXT: %[[INT1:.*]] = torch.constant.int 1
+// CHECK-NEXT: %[[INT1_0:.*]] = torch.constant.int 1
+// CHECK-NEXT: %[[T0:.*]] = torch.aten.add.int %[[INT0]], %[[INT1_0]] : !torch.int, !torch.int -> !torch.int
+// CHECK-NEXT: %[[T1:.*]] = torch.aten.unsqueeze %[[SRC]], %[[INT1]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?,1],f32>
+// CHECK-NEXT: %[[INT1_1:.*]] = torch.constant.int 1
+// CHECK-NEXT: %[[INT0_2:.*]] = torch.constant.int 0
+// CHECK-NEXT: %[[NONE:.*]] = torch.constant.none
+// CHECK-NEXT: %[[T2:.*]] = torch.aten.size.int %[[SELF]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+// CHECK-NEXT: %[[INT0_3:.*]] = torch.constant.int 0
+// CHECK-NEXT: %[[INT1_4:.*]] = torch.constant.int 1
+// CHECK-NEXT: %[[T3:.*]] = torch.aten.arange.start_step %[[INT0_3]], %[[T2]], %[[INT1_4]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
+// CHECK-NEXT: %[[T4:.*]] = torch.prim.ListConstruct %[[INT1_1]], %[[T2]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK-NEXT: %[[T5:.*]] = torch.aten.view %[[T3]], %[[T4]] : !torch.vtensor<[?],si64>, !torch.list<int> -> !torch.vtensor<[1,?],si64>
+// CHECK-NEXT: %[[T6:.*]] = torch.aten.sub.Scalar %[[T5]], %[[INT0]], %[[INT1_1]] : !torch.vtensor<[1,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64>
+// CHECK-NEXT: %[[T7:.*]] = torch.aten.remainder.Scalar %[[T6]], %[[INT1_0]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],si64>
+// CHECK-NEXT: %[[T8:.*]] = torch.aten.eq.Scalar %[[T7]], %[[INT0_2]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],i1>
+// CHECK-NEXT: %[[T9:.*]] = torch.aten.ge.Scalar %[[T6]], %[[INT0_2]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],i1>
+// CHECK-NEXT: %[[T10:.*]] = torch.aten.lt.Scalar %[[T5]], %[[T0]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],i1>
+// CHECK-NEXT: %[[T11:.*]] = torch.aten.bitwise_and.Tensor %[[T8]], %[[T9]] : !torch.vtensor<[1,?],i1>, !torch.vtensor<[1,?],i1> -> !torch.vtensor<[1,?],i1>
+// CHECK-NEXT: %[[T12:.*]] = torch.aten.bitwise_and.Tensor %[[T11]], %[[T10]] : !torch.vtensor<[1,?],i1>, !torch.vtensor<[1,?],i1> -> !torch.vtensor<[1,?],i1>
+// CHECK-NEXT: %[[T13:.*]] = torch.aten.where.self %[[T12]], %[[T1]], %[[SELF]] : !torch.vtensor<[1,?],i1>, !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+// CHECK-NEXT: return %[[T13]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
   %int0 = torch.constant.int 0
   %int1 = torch.constant.int 1