mirror of https://github.com/llvm/torch-mlir

Decompose torch.slice_scatter (#1622)

* Decompose torch.slice_scatter
* fix compilation error
* update file check
* fix ci
* fix i64 torch.tensor dtype

Branch: pull/1637/head
Tag: snapshot-20221123.666

parent da8fdc9f96
commit f3f2f10030
@@ -2963,6 +2963,128 @@ public:
 };
 } // namespace
 
+namespace {
+// def slice_scatter(self, values, dim, start, end, step):
+//     size = self.size(dim)
+//     indices = torch.arange(size)
+//     shift_indices = indices - start
+//     mask = shift_indices % step == 0
+//     start_mask = shift_indices >= 0
+//     end_mask = shift_indices < end
+//     mask = mask * start_mask
+//     mask = mask * end_mask
+//     sizes = list(self.size())
+//     rank = len(sizes)
+//     shape = [1] * rank
+//     shape[dim] = size
+//     mask = mask.view(shape)
+//     return torch.where(mask, values, self)
+//
+class DecomposeAtenSliceScatterOp
+    : public OpRewritePattern<AtenSliceScatterOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSliceScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    int64_t inputRank = getTensorRank(op.self());
+    int64_t dimInt = 0;
+    if (matchPattern(op.dim(), m_TorchConstantInt(&dimInt))) {
+      dimInt = toPositiveDim(dimInt, inputRank);
+      if (!isValidDim(dimInt, inputRank))
+        return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
+    } else {
+      return rewriter.notifyMatchFailure(op, "dim must be constant");
+    }
+
+    auto getOptionalVal = [&](Value val, Value defVal) -> Value {
+      if (val.getType().isa<Torch::NoneType>()) {
+        return defVal;
+      } else {
+        return val;
+      }
+    };
+
+    Value one = rewriter.create<Torch::ConstantIntOp>(
+        op.getLoc(), rewriter.getI64IntegerAttr(1));
+    Value zero = rewriter.create<Torch::ConstantIntOp>(
+        op.getLoc(), rewriter.getI64IntegerAttr(0));
+    Value none = rewriter.create<ConstantNoneOp>(op.getLoc());
+    Value dimSize =
+        rewriter.create<AtenSizeIntOp>(op.getLoc(), op.self(), op.dim());
+
+    Value start = getOptionalVal(op.start(), zero);
+    Value end = getOptionalVal(op.end(), dimSize);
+    Value step = getOptionalVal(op.step(), one);
+    // Step 0. create indices
+    Type indicesType = ValueTensorType::get(
+        op.getContext(), ArrayRef<int64_t>{ShapedType::kDynamicSize},
+        IntegerType::get(op.getContext(), 64, IntegerType::Signed));
+    Value indices = rewriter.create<AtenArangeOp>(
+        op.getLoc(), indicesType, dimSize, none, none, none, none);
+
+    // Step 1. make indices broadcastable to self's shape
+    SmallVector<int64_t> newIndicesShapeInt(inputRank, 1);
+    SmallVector<Value> newIndicesShape(inputRank, one);
+    newIndicesShape[dimInt] = dimSize;
+    newIndicesShapeInt[dimInt] = ShapedType::kDynamicSize;
+    Value newIndicesSizeList = rewriter.create<PrimListConstructOp>(
+        op.getLoc(), ListType::get(IntType::get(op.getContext())),
+        newIndicesShape);
+    Type indicesDtype = indices.getType().cast<ValueTensorType>().getDtype();
+    Type newIndicesType = ValueTensorType::get(
+        op.getContext(), llvm::makeArrayRef(newIndicesShapeInt), indicesDtype);
+    indices = rewriter.create<AtenViewOp>(op.getLoc(), newIndicesType,
+                                          indices, newIndicesSizeList);
+
+    // Step 2. calculate scatter indices mask
+    Type maskType = ValueTensorType::get(
+        op.getContext(), newIndicesType.cast<ValueTensorType>().getSizes(),
+        IntegerType::get(op.getContext(), 1));
+    auto shiftIndices = rewriter.create<AtenSubScalarOp>(
+        op.getLoc(), indices.getType(), indices, start, one);
+    auto stepRemainder = rewriter.create<AtenRemainderScalarOp>(
+        op.getLoc(), indices.getType(), shiftIndices, step);
+    Value mask = rewriter.create<AtenEqScalarOp>(op.getLoc(), maskType,
+                                                 stepRemainder, zero);
+    auto maskStart = rewriter.create<AtenGeScalarOp>(op.getLoc(), maskType,
+                                                     shiftIndices, zero);
+    auto maskEnd =
+        rewriter.create<AtenLtScalarOp>(op.getLoc(), maskType, indices, end);
+    mask = rewriter.create<AtenBitwiseAndTensorOp>(op.getLoc(), maskType, mask,
+                                                   maskStart);
+    mask = rewriter.create<AtenBitwiseAndTensorOp>(op.getLoc(), maskType, mask,
+                                                   maskEnd);
+
+    // Step 3. make src broadcastable to self's shape
+    Value src = op.src();
+    BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
+    if (!srcTensorType.hasSizes())
+      return rewriter.notifyMatchFailure(op, "src tensor must have size");
+
+    ArrayRef<int64_t> srcShape = srcTensorType.getSizes();
+    int64_t srcRank = srcShape.size();
+    if (srcRank != inputRank) {
+      if (srcRank + 1 == inputRank) {
+        SmallVector<int64_t> sizes;
+        sizes.append(srcShape.begin(), srcShape.end());
+        sizes.insert(sizes.begin() + dimInt, 1);
+        Type srcType = srcTensorType.getWithSizesAndDtype(
+            llvm::makeArrayRef(sizes), srcTensorType.getDtype());
+        src = rewriter.create<AtenUnsqueezeOp>(op.getLoc(), srcType, src,
+                                               op.dim());
+      } else {
+        return rewriter.notifyMatchFailure(op, "src's rank doesn't match");
+      }
+    }
+
+    // Step 4. replace output = mask? src: self
+    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, op.getType(), mask,
+                                                 src, op.self());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAten_EmbeddingBagOp
     : public OpRewritePattern<Aten_EmbeddingBagOp> {
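The Python-style comment at the top of the new pattern describes the intended decomposition. Below is a minimal eager-mode sketch of that recipe, not part of the commit; the helper name slice_scatter_ref is ours, and, following the C++ pattern rather than the comment, the end bound is compared against the raw indices. Since torch.where needs values to broadcast against self, the check scatters a single row.

import torch

# Hypothetical reference mirroring the decomposition sketched in the comment above.
def slice_scatter_ref(self, values, dim, start, end, step):
    size = self.size(dim)
    indices = torch.arange(size)
    shift_indices = indices - start
    mask = shift_indices % step == 0   # positions hit by the step
    start_mask = shift_indices >= 0    # at or after `start`
    end_mask = indices < end           # before `end` (raw indices, as in the C++ pattern)
    mask = mask & start_mask & end_mask
    # Reshape the 1-D mask so it broadcasts against `self` along `dim`.
    shape = [1] * self.dim()
    shape[dim] = size
    mask = mask.view(shape)
    return torch.where(mask, values, self)

self_t = torch.zeros(8, 4)
src = torch.ones(1, 4)  # replaces row 2
expected = torch.slice_scatter(self_t, src, dim=0, start=2, end=3, step=1)
got = slice_scatter_ref(self_t, src, dim=0, start=2, end=3, step=1)
assert torch.equal(expected, got)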
@@ -3354,6 +3476,8 @@ public:
     target.addIllegalOp<AtenNumpyTOp>();
     patterns.add<DecomposeAtenSelectScatterOp>(context);
     target.addIllegalOp<AtenSelectScatterOp>();
+    patterns.add<DecomposeAtenSliceScatterOp>(context);
+    target.addIllegalOp<AtenSliceScatterOp>();
     patterns.add<DecomposeAtenVarDimOp>(context);
     target.addIllegalOp<AtenVarDimOp>();
    patterns.add<DecomposeAtenVarCorrectionOp>(context);
@@ -240,7 +240,7 @@ class ExampleArgs:
 # compiler where each backend can "own" its set of legal ops.
 BACKEND_LEGAL_OPS = {
     OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
-    OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
+    OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', 'torch.aten.slice_scatter'],
     OutputType.MHLO: [],
 }
 
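The BACKEND_LEGAL_OPS change above marks torch.aten.slice_scatter as backend-legal for the linalg-on-tensors path, so the decompose pass leaves it intact there for the backend to lower directly, while other paths can use the new decomposition. A rough usage sketch, assuming the torch_mlir.compile API of this snapshot; the toy module and shapes are ours.

import torch
import torch_mlir  # assumed: the torch-mlir Python package from this snapshot

class SliceScatterModule(torch.nn.Module):
    def forward(self, inp, src):
        # Overwrite rows 1, 3, 5 of `inp` with the rows of `src`.
        return torch.slice_scatter(inp, src, dim=0, start=1, end=6, step=2)

# Compile for the linalg-on-tensors backend; with this change,
# torch.aten.slice_scatter stays backend-legal instead of being decomposed.
compiled = torch_mlir.compile(
    SliceScatterModule(), [torch.rand(8, 4), torch.rand(3, 4)],
    output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
print(compiled)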
@@ -784,7 +784,7 @@ func.func @torch.aten.numpy_T$rank_three(%arg0: !torch.vtensor<[5,4,3],f32>) ->
 }
 
 // -----
-// CHECK-LABEL: func.func @torch.aten.repeat(
+// CHECK-LABEL: func @torch.aten.repeat(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int, %[[ARG2:.*]]: !torch.int, %[[ARG3:.*]]: !torch.int) -> !torch.vtensor<[?,?,?],f32> {
 // CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]], %[[ARG3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
@@ -810,14 +810,29 @@ func.func @torch.aten.repeat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int
 // -----
 // CHECK-LABEL: func @torch.aten.select_scatter
 // CHECK-SAME: (%[[SELF:.*]]: !torch.vtensor<[?,?],f32>, %[[SRC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK-NEXT: %[[START:.*]] = torch.constant.int 0
-// CHECK-NEXT: %[[DIM:.*]] = torch.constant.int 1
-// CHECK-NEXT: %[[STEP:.*]] = torch.constant.int 1
-// CHECK-NEXT: %[[END:.*]] = torch.aten.add.int %[[START]], %[[STEP]]
-// CHECK-NEXT: %[[UNSQUEEZE_SRC:.*]] = torch.aten.unsqueeze %[[SRC]], %[[DIM]]
-// CHECK-NEXT: %[[SLICE_SCATTER:.*]] = torch.aten.slice_scatter %[[SELF]], %[[UNSQUEEZE_SRC]], %[[DIM]], %[[START]], %[[END]], %[[STEP]]
-// CHECK-NEXT: return %[[SLICE_SCATTER]]
-// CHECK-NEXT: }
+// CHECK-NEXT: %[[INT0:.*]] = torch.constant.int 0
+// CHECK-NEXT: %[[INT1:.*]] = torch.constant.int 1
+// CHECK-NEXT: %[[INT1_0:.*]] = torch.constant.int 1
+// CHECK-NEXT: %[[T0:.*]] = torch.aten.add.int %[[INT0]], %[[INT1_0]] : !torch.int, !torch.int -> !torch.int
+// CHECK-NEXT: %[[T1:.*]] = torch.aten.unsqueeze %[[SRC]], %[[INT1]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?,1],f32>
+// CHECK-NEXT: %[[INT1_1:.*]] = torch.constant.int 1
+// CHECK-NEXT: %[[INT0_2:.*]] = torch.constant.int 0
+// CHECK-NEXT: %[[NONE:.*]] = torch.constant.none
+// CHECK-NEXT: %[[T2:.*]] = torch.aten.size.int %[[SELF]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+// CHECK-NEXT: %[[INT0_3:.*]] = torch.constant.int 0
+// CHECK-NEXT: %[[INT1_4:.*]] = torch.constant.int 1
+// CHECK-NEXT: %[[T3:.*]] = torch.aten.arange.start_step %[[INT0_3]], %[[T2]], %[[INT1_4]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
+// CHECK-NEXT: %[[T4:.*]] = torch.prim.ListConstruct %[[INT1_1]], %[[T2]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK-NEXT: %[[T5:.*]] = torch.aten.view %[[T3]], %[[T4]] : !torch.vtensor<[?],si64>, !torch.list<int> -> !torch.vtensor<[1,?],si64>
+// CHECK-NEXT: %[[T6:.*]] = torch.aten.sub.Scalar %[[T5]], %[[INT0]], %[[INT1_1]] : !torch.vtensor<[1,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64>
+// CHECK-NEXT: %[[T7:.*]] = torch.aten.remainder.Scalar %[[T6]], %[[INT1_0]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],si64>
+// CHECK-NEXT: %[[T8:.*]] = torch.aten.eq.Scalar %[[T7]], %[[INT0_2]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],i1>
+// CHECK-NEXT: %[[T9:.*]] = torch.aten.ge.Scalar %[[T6]], %[[INT0_2]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],i1>
+// CHECK-NEXT: %[[T10:.*]] = torch.aten.lt.Scalar %[[T5]], %[[T0]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],i1>
+// CHECK-NEXT: %[[T11:.*]] = torch.aten.bitwise_and.Tensor %[[T8]], %[[T9]] : !torch.vtensor<[1,?],i1>, !torch.vtensor<[1,?],i1> -> !torch.vtensor<[1,?],i1>
+// CHECK-NEXT: %[[T12:.*]] = torch.aten.bitwise_and.Tensor %[[T11]], %[[T10]] : !torch.vtensor<[1,?],i1>, !torch.vtensor<[1,?],i1> -> !torch.vtensor<[1,?],i1>
+// CHECK-NEXT: %[[T13:.*]] = torch.aten.where.self %[[T12]], %[[T1]], %[[SELF]] : !torch.vtensor<[1,?],i1>, !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+// CHECK-NEXT: return %[[T13]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
   %int0 = torch.constant.int 0
   %int1 = torch.constant.int 1
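The updated CHECK lines follow torch.aten.select_scatter through both decompositions: first into unsqueeze plus slice_scatter over [start, start + step) with step 1, then into the arange/mask/where form added by this commit. In eager PyTorch the first step corresponds to the identity below, an illustrative sketch with made-up shapes.

import torch

self_t = torch.rand(4, 5)
src = torch.rand(4)  # one column's worth of values

# select_scatter at index 0 of dim 1 equals slice_scattering the unsqueezed
# src over [0, 0 + 1) with step 1, the same chain the CHECK lines verify.
a = torch.select_scatter(self_t, src, dim=1, index=0)
b = torch.slice_scatter(self_t, src.unsqueeze(1), dim=1, start=0, end=1, step=1)
assert torch.equal(a, b)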