mirror of https://github.com/llvm/torch-mlir
[torch-dialect] emit aten.index_add and decompose it to scatter.add op
parent
28193fd985
commit
4e5577ad88
|
@ -5703,6 +5703,59 @@ def Torch_AtenTril_Op : Torch_Op<"aten.tril_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenIndexAddOp : Torch_Op<"aten.index_add", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::index_add : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchTensorType:$index,
|
||||
AnyTorchTensorType:$source,
|
||||
AnyTorchScalarType:$alpha
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenIndexAddOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 5, 1);
|
||||
}
|
||||
void AtenIndexAddOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 5, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenIndexAdd_Op : Torch_Op<"aten.index_add_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::index_add_ : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
Torch_NonValueTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
Torch_NonValueTensorType:$index,
|
||||
Torch_NonValueTensorType:$source,
|
||||
AnyTorchScalarType:$alpha
|
||||
);
|
||||
let results = (outs
|
||||
Torch_NonValueTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenIndexAdd_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 5, 1);
|
||||
}
|
||||
void AtenIndexAdd_Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 5, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -9185,6 +9185,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.scatter_add\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.index_add\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.float) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.index_put\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -10399,6 +10407,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.index_add\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.number) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.scatter_add\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.index_select\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -5621,6 +5621,75 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.index_add` op into `aten.index_put`
|
||||
class DecomposeAtenIndexAddOp : public OpRewritePattern<AtenIndexAddOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenIndexAddOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value src = op.getSource();
|
||||
Value input = op.getSelf();
|
||||
Value index = op.getIndex();
|
||||
Value alpha = op.getAlpha();
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"dim of index_add must be a constant");
|
||||
}
|
||||
std::optional<unsigned> maybeInputRank = getTensorRank(input);
|
||||
if (!maybeInputRank) {
|
||||
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
||||
}
|
||||
int64_t inputRank = static_cast<int64_t>(*maybeInputRank);
|
||||
dim = toPositiveDim(dim, inputRank);
|
||||
if (!isValidDim(dim, inputRank)) {
|
||||
return rewriter.notifyMatchFailure(op, "index dim is not a valid dim");
|
||||
}
|
||||
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
auto srcType = src.getType().cast<BaseTensorType>();
|
||||
auto indexType = index.getType().cast<BaseTensorType>();
|
||||
if (!indexType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "index should have dtype");
|
||||
}
|
||||
auto indexDtype = indexType.getDtype();
|
||||
|
||||
// calculate src * alpha first.
|
||||
Value newSrc =
|
||||
rewriter.create<Torch::AtenMulScalarOp>(loc, srcType, src, alpha);
|
||||
|
||||
// broadcast index to have the same shape as src.
|
||||
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(-1));
|
||||
for (int64_t i = dim + 1; i < inputRank; ++i) {
|
||||
index = *unsqueezeTensor(rewriter, op, index, /*dim=*/constMinusOne);
|
||||
}
|
||||
|
||||
SmallVector<int64_t> bcastShape;
|
||||
SmallVector<Value> bcastShapeValue;
|
||||
computeBroadcastShape(rewriter, loc, index, src, bcastShape,
|
||||
bcastShapeValue);
|
||||
|
||||
Type bcastType = ValueTensorType::get(
|
||||
op.getContext(), llvm::ArrayRef(bcastShape), indexDtype);
|
||||
|
||||
Value indexBcastShapeTorchList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
||||
bcastShapeValue);
|
||||
|
||||
index = rewriter.create<Torch::AtenBroadcastToOp>(loc, bcastType, index,
|
||||
indexBcastShapeTorchList);
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenScatterAddOp>(op, resType, input,
|
||||
op.getDim(), index, newSrc);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
@ -8021,6 +8090,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNewFullOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexAddOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandAsOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenCopyOp>(patterns);
|
||||
|
|
|
@ -471,6 +471,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenMishOp>();
|
||||
target.addIllegalOp<AtenFullLikeOp>();
|
||||
target.addIllegalOp<AtenNewFullOp>();
|
||||
target.addIllegalOp<AtenIndexAddOp>();
|
||||
target.addIllegalOp<AtenExpandAsOp>();
|
||||
target.addIllegalOp<Aten_ToCopyOp>();
|
||||
target.addIllegalOp<AtenDropoutOp>();
|
||||
|
|
|
@ -1607,15 +1607,18 @@ def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int],
|
|||
def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.index_select(self, dim, index)
|
||||
|
||||
def aten〇scatter_add〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇index_add〡shape(self: List[int], dim: int, index: List[int], source: List[int], alpha: float = 1) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇index_put〡shape(self: List[int], indices: List[Optional[List[int]]], values: List[int], accumulate: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇index_put〇hacked_twin〡shape(self: List[int], indices: List[List[int]], values: List[int], accumulate: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇embedding〡shape(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)
|
||||
|
||||
def aten〇embedding_bag〇padding_idx〡shape(weight: List[int], indices: List[int], offsets: List[int], scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[List[int]], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[List[int], List[int], List[int], List[int]]:
|
||||
return _embedding_bag_helper(weight, indices, offsets, include_last_offset,
|
||||
mode, per_sample_weights, padding_idx)
|
||||
|
@ -2534,6 +2537,16 @@ def aten〇index_put〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtyp
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function([Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
|
||||
def aten〇index_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], source_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function([Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
|
||||
def aten〇scatter_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, None, 0, TensorOfShape(1, dtype=torch.int64)))
|
||||
def aten〇index_select〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -516,6 +516,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
|
||||
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::index_add : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
|
||||
emit_with_mutating_variants(
|
||||
"aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)"
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue