[torch-dialect] emit aten.index_add and decompose it into the aten.scatter_add op

pull/3085/head
wujiawei.aml 2024-05-08 22:36:24 +08:00
parent 28193fd985
commit 4e5577ad88
6 changed files with 158 additions and 3 deletions
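For context, the rewrite implemented here relies on the functional equivalence between index_add and scatter_add: scale the source by alpha, expand the 1-D index to the source's shape, then scatter-add along the chosen dimension. A minimal PyTorch sketch of that identity (the tensors and shapes below are illustrative, not part of the patch):

import torch

# self.index_add(dim, index, source, alpha=a) adds a * source into self at the
# rows selected by index along dim.
self_t = torch.zeros(4, 3)
index = torch.tensor([0, 2])
source = torch.ones(2, 3)
alpha = 2.0

expected = torch.index_add(self_t, 0, index, source, alpha=alpha)

# Same result via scatter_add: scale the source, broadcast the index to the
# source's shape, and accumulate along dim 0.
bcast_index = index.unsqueeze(-1).expand_as(source)
via_scatter = self_t.scatter_add(0, bcast_index, source * alpha)

assert torch.allclose(expected, via_scatter)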

View File

@@ -5703,6 +5703,59 @@ def Torch_AtenTril_Op : Torch_Op<"aten.tril_", [
}];
}
def Torch_AtenIndexAddOp : Torch_Op<"aten.index_add", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::index_add : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$source,
AnyTorchScalarType:$alpha
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIndexAddOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenIndexAddOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
def Torch_AtenIndexAdd_Op : Torch_Op<"aten.index_add_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::index_add_ : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_IntType:$dim,
Torch_NonValueTensorType:$index,
Torch_NonValueTensorType:$source,
AnyTorchScalarType:$alpha
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIndexAdd_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenIndexAdd_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
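The two definitions above are the value-semantics op and its trailing-underscore in-place variant, mirroring the eager PyTorch API. A small illustration of the difference (example values are hypothetical):

import torch

x = torch.zeros(3)
idx = torch.tensor([1])
src = torch.tensor([5.0])

# Out-of-place form, corresponds to aten.index_add: x is left untouched.
y = torch.index_add(x, 0, idx, src)

# In-place form, corresponds to aten.index_add_: x itself is updated.
x.index_add_(0, idx, src)

assert torch.equal(x, y)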
def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@@ -9185,6 +9185,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.scatter_add\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.index_add\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.index_put\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.index_put\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -10399,6 +10407,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n" " return %0#1 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.index_add\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scatter_add\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.index_select\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.index_select\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n" " return %0#1 : !torch.int\n"

View File

@@ -5621,6 +5621,75 @@ public:
};
} // namespace
namespace {
// Decompose `aten.index_add` into `aten.scatter_add`: scale the source by
// alpha, broadcast the index to a shape compatible with the source, and
// accumulate with a scatter-add along `dim`.
class DecomposeAtenIndexAddOp : public OpRewritePattern<AtenIndexAddOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenIndexAddOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value src = op.getSource();
    Value input = op.getSelf();
    Value index = op.getIndex();
    Value alpha = op.getAlpha();

    int64_t dim;
    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
      return rewriter.notifyMatchFailure(
          op, "dim of index_add must be a constant");
    }
    std::optional<unsigned> maybeInputRank = getTensorRank(input);
    if (!maybeInputRank) {
      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
    }
    int64_t inputRank = static_cast<int64_t>(*maybeInputRank);
    dim = toPositiveDim(dim, inputRank);
    if (!isValidDim(dim, inputRank)) {
      return rewriter.notifyMatchFailure(op, "index dim is not a valid dim");
    }

    auto resType = op.getType().cast<BaseTensorType>();
    auto srcType = src.getType().cast<BaseTensorType>();
    auto indexType = index.getType().cast<BaseTensorType>();
    if (!indexType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "index should have dtype");
    }
    auto indexDtype = indexType.getDtype();

    // Compute `source * alpha` first.
    Value newSrc =
        rewriter.create<Torch::AtenMulScalarOp>(loc, srcType, src, alpha);

    // Unsqueeze the 1-D index until it has the same rank as the input, then
    // broadcast it so it lines up with the (scaled) source.
    Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(-1));
    for (int64_t i = dim + 1; i < inputRank; ++i) {
      auto unsqueezed =
          unsqueezeTensor(rewriter, op, index, /*dim=*/constMinusOne);
      if (failed(unsqueezed))
        return rewriter.notifyMatchFailure(op, "cannot unsqueeze index");
      index = *unsqueezed;
    }
    SmallVector<int64_t> bcastShape;
    SmallVector<Value> bcastShapeValue;
    computeBroadcastShape(rewriter, loc, index, src, bcastShape,
                          bcastShapeValue);
    Type bcastType = ValueTensorType::get(
        op.getContext(), llvm::ArrayRef(bcastShape), indexDtype);
    Value indexBcastShapeTorchList = rewriter.create<PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
        bcastShapeValue);
    index = rewriter.create<Torch::AtenBroadcastToOp>(
        loc, bcastType, index, indexBcastShapeTorchList);

    // Replace with `aten.scatter_add` on the broadcast index and scaled source.
    rewriter.replaceOpWithNewOp<Torch::AtenScatterAddOp>(
        op, resType, input, op.getDim(), index, newSrc);
    return success();
  }
};
} // namespace
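The pattern above lowers aten.index_add as: multiply the source by alpha (aten.mul.Scalar), unsqueeze the 1-D index until it matches the input rank, broadcast it against the source, and finish with aten.scatter_add. A Python sketch of the same sequence of steps (the helper name and shapes are illustrative only):

import torch

def index_add_via_scatter_add(self_t, dim, index, source, alpha=1):
    # 1) Scale the source by alpha.
    scaled = source * alpha
    # 2) Unsqueeze the 1-D index until it has the input's rank.
    expanded_index = index
    for _ in range(dim + 1, self_t.dim()):
        expanded_index = expanded_index.unsqueeze(-1)
    # 3) Broadcast the index against the source.
    expanded_index = torch.broadcast_tensors(expanded_index, source)[0]
    # 4) Accumulate with scatter_add along dim.
    return self_t.scatter_add(dim, expanded_index, scaled)

x = torch.zeros(2, 4, 3)
idx = torch.tensor([0, 3])
src = torch.ones(2, 2, 3)
assert torch.allclose(index_add_via_scatter_add(x, 1, idx, src, alpha=0.5),
                      torch.index_add(x, 1, idx, src, alpha=0.5))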
namespace {
class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
using OpRewritePattern::OpRewritePattern;
@@ -8021,6 +8090,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewFullOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexAddOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCopyOp>(patterns);

View File

@@ -471,6 +471,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenMishOp>();
target.addIllegalOp<AtenFullLikeOp>();
target.addIllegalOp<AtenNewFullOp>();
target.addIllegalOp<AtenIndexAddOp>();
target.addIllegalOp<AtenExpandAsOp>();
target.addIllegalOp<Aten_ToCopyOp>();
target.addIllegalOp<AtenDropoutOp>();

View File

@@ -1607,15 +1607,18 @@ def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int],
def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]:
    return upstream_shape_functions.index_select(self, dim, index)
def aten〇scatter_add〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)
def aten〇index_add〡shape(self: List[int], dim: int, index: List[int], source: List[int], alpha: float = 1) -> List[int]:
    return upstream_shape_functions.unary(self)
def aten〇index_put〡shape(self: List[int], indices: List[Optional[List[int]]], values: List[int], accumulate: bool = False) -> List[int]:
    return upstream_shape_functions.unary(self)
def aten〇index_put〇hacked_twin〡shape(self: List[int], indices: List[List[int]], values: List[int], accumulate: bool = False) -> List[int]:
    return upstream_shape_functions.unary(self)
def aten〇embedding〡shape(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]:
    return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)
def aten〇embedding_bag〇padding_idx〡shape(weight: List[int], indices: List[int], offsets: List[int], scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[List[int]], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[List[int], List[int], List[int], List[int]]:
    return _embedding_bag_helper(weight, indices, offsets, include_last_offset,
                                 mode, per_sample_weights, padding_idx)
@@ -2534,6 +2537,16 @@ def aten〇index_put〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtyp
    self_rank, self_dtype = self_rank_dtype
    return self_dtype
@check_dtype_function([Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
def aten〇index_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], source_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype
@check_dtype_function([Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
def aten〇scatter_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, None, 0, TensorOfShape(1, dtype=torch.int64)))
def aten〇index_select〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
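As encoded above, both the shape and dtype transfer functions for aten.index_add simply propagate self: the result keeps self's shape (unary) and self's dtype. A quick PyTorch check of that behavior (tensor values are illustrative):

import torch

self_t = torch.zeros(4, 3, dtype=torch.float64)
index = torch.tensor([0, 2])
source = torch.ones(2, 3, dtype=torch.float64)

out = torch.index_add(self_t, 0, index, source)
# Result shape and dtype match self, as the shape/dtype functions assert.
assert out.shape == self_t.shape
assert out.dtype == self_t.dtype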

View File

@@ -516,6 +516,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants("aten::index_add : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)")
emit_with_mutating_variants(
    "aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)"
)