mirror of https://github.com/llvm/torch-mlir
[TM_TENSOR] Add `aten.scatter.[src|value]` op
This commit adds support of `aten.scatter.src` and `aten.scatter.value` ops. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/2183/head snapshot-20230529.853
parent
b9d29dc055
commit
552887783a
|
@ -263,6 +263,11 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
# ERROR: Exception: Unsupported: return type List[Tensor] in schema for aten.unbind.int
|
||||
"UnbindIntListUnpack_Module_basic",
|
||||
"UnbindIntGetItem_Module_basic",
|
||||
|
||||
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
|
||||
"ScatterValueFloatModule_basic",
|
||||
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
|
||||
"ScatterValueIntModule_basic",
|
||||
}
|
||||
|
||||
TORCHDYNAMO_CRASHING_SET = {
|
||||
|
@ -1247,4 +1252,6 @@ LTC_XFAIL_SET = {
|
|||
"ChunkListUnpackUneven_Module_basic",
|
||||
"ChunkListUnpackDynamic_Module_basic",
|
||||
"ChunkListUnpackUnevenDynamic_Module_basic",
|
||||
"ScatterValueFloatModule_basic",
|
||||
"ScatterValueIntModule_basic",
|
||||
}
|
||||
|
|
|
@ -5073,6 +5073,108 @@ def Torch_Aten_LogSoftmaxOp : Torch_Op<"aten._log_softmax", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchTensorType:$index,
|
||||
AnyTorchTensorType:$src
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenScatterSrcOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenScatterSrcOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenScatter_SrcOp : Torch_Op<"aten.scatter_.src", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::scatter_.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchTensorType:$index,
|
||||
AnyTorchTensorType:$src
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenScatter_SrcOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenScatter_SrcOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenScatterValueOp : Torch_Op<"aten.scatter.value", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchTensorType:$index,
|
||||
AnyTorchScalarType:$value
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenScatterValueOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenScatterValueOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenScatter_ValueOp : Torch_Op<"aten.scatter_.value", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::scatter_.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchTensorType:$index,
|
||||
AnyTorchScalarType:$value
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenScatter_ValueOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenScatter_ValueOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -8899,58 +9001,6 @@ def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchTensorType:$index,
|
||||
AnyTorchTensorType:$src
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenScatterSrcOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenScatterSrcOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenScatterValueOp : Torch_Op<"aten.scatter.value", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchTensorType:$index,
|
||||
AnyTorchScalarType:$value
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenScatterValueOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenScatterValueOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -299,6 +299,55 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
|
|||
return SmallVector<Value>(sortOp.getResults());
|
||||
}
|
||||
|
||||
namespace {
|
||||
class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenScatterSrcOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Location loc = op.getLoc();
|
||||
TypeConverter *typeConverter = getTypeConverter();
|
||||
Value self = adaptor.getSelf();
|
||||
Value index = adaptor.getIndex();
|
||||
Value src = adaptor.getSrc();
|
||||
|
||||
RankedTensorType selfType = self.getType().cast<RankedTensorType>();
|
||||
RankedTensorType indexType = index.getType().cast<RankedTensorType>();
|
||||
RankedTensorType srcType = src.getType().cast<RankedTensorType>();
|
||||
if (selfType.getRank() != indexType.getRank() ||
|
||||
indexType.getRank() != srcType.getRank())
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"'self', 'index' and 'src' should all"
|
||||
"have the same number of dimensions.");
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"unimplemented: dim is not constant");
|
||||
|
||||
// Get the inputs reformatted for the TMScatterOp
|
||||
auto [indices, updates] =
|
||||
convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(rewriter, index,
|
||||
src, dim);
|
||||
Value scatterOp = createTMTensorScatterOp(
|
||||
rewriter, loc, updates, indices, self,
|
||||
/*uniqueIndices=*/false,
|
||||
[&](OpBuilder &b, Location loc, Value updatesElement,
|
||||
Value inputElement) {
|
||||
b.create<TMTensor::YieldOp>(loc, updatesElement);
|
||||
});
|
||||
|
||||
auto resultType = typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// aten::bincount op counts the frequency of each value in a 1-d input tensor of
|
||||
// non-negative ints.
|
||||
|
@ -1606,6 +1655,9 @@ public:
|
|||
patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
|
||||
context);
|
||||
|
||||
target.addIllegalOp<AtenScatterSrcOp>();
|
||||
patterns.add<ConvertAtenScatterSrcOp>(typeConverter, context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
|
|
@ -7403,6 +7403,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_shape_fn.aten.scatter_reduce.two\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.scatter.src\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.scatter.value\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.float) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -8371,6 +8377,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.scatter.src\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.scatter.value\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.silu\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -4399,6 +4399,50 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.scatter.value` op into `aten.scatter.src` op.
|
||||
class DecomposeAtenScatterValueOp
|
||||
: public OpRewritePattern<AtenScatterValueOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenScatterValueOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = op.getContext();
|
||||
Value self = op.getSelf();
|
||||
Value index = op.getIndex();
|
||||
std::optional<unsigned> maybeIndexRank = getTensorRank(index);
|
||||
if (!maybeIndexRank) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected index tensor to have a rank");
|
||||
}
|
||||
unsigned indexRank = *maybeIndexRank;
|
||||
SmallVector<Value> sizes;
|
||||
for (int64_t i = 0; i < indexRank; ++i) {
|
||||
Value dim =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
|
||||
sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, index, /*dim=*/dim));
|
||||
}
|
||||
Value sizeList = rewriter.create<PrimListConstructOp>(
|
||||
loc, ListType::get(IntType::get(context)), sizes);
|
||||
|
||||
auto selfType = self.getType().cast<BaseTensorType>();
|
||||
auto indexType = index.getType().cast<BaseTensorType>();
|
||||
BaseTensorType srcType =
|
||||
selfType
|
||||
.getWithSizesAndDtype(indexType.getOptionalSizes(),
|
||||
selfType.getOptionalDtype())
|
||||
.cast<BaseTensorType>();
|
||||
Value src =
|
||||
createInitTensor(rewriter, loc, srcType, op.getValue(), sizeList);
|
||||
rewriter.replaceOpWithNewOp<AtenScatterSrcOp>(op, op.getType(), self,
|
||||
op.getDim(), index, src);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeComplexOpsPass
|
||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||
|
@ -4563,6 +4607,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
|
||||
|
||||
GreedyRewriteConfig config;
|
||||
config.useTopDownTraversal = true;
|
||||
|
|
|
@ -479,6 +479,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenCrossEntropyLossOp>();
|
||||
target.addIllegalOp<AtenVarMeanDimOp>();
|
||||
target.addIllegalOp<AtenTopkOp>();
|
||||
target.addIllegalOp<AtenScatterValueOp>();
|
||||
for (auto &opName : backendLegalOpsSet) {
|
||||
target.addLegalOp(
|
||||
OperationName(kTorchOpPrefix + opName.first().str(), context));
|
||||
|
|
|
@ -886,6 +886,12 @@ def aten〇select_scatter〡shape(self: List[int], src: List[int], dim: int, ind
|
|||
def aten〇scatter_reduce〇two〡shape(self: List[int], dim: int, index: List[int], src: List[int], reduce: str, include_self: bool = True) -> List[int]:
|
||||
return self
|
||||
|
||||
def aten〇scatter〇src〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]:
|
||||
return self
|
||||
|
||||
def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int], value: float) -> List[int]:
|
||||
return self
|
||||
|
||||
def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.index_select(self, dim, index)
|
||||
|
||||
|
@ -1716,6 +1722,18 @@ def aten〇select_scatter〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dty
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(
|
||||
[Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
|
||||
def aten〇scatter〇src〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(
|
||||
[Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), 1.0) for dtype in _SORTED_TORCH_TYPES])
|
||||
def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇silu〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -410,6 +410,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::_log_softmax : (Tensor, int, bool) -> (Tensor)"
|
||||
)
|
||||
emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
|
||||
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
|
||||
|
@ -559,8 +561,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
|
||||
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
|
||||
|
|
|
@ -821,6 +821,102 @@ def IndexPutHackedTwin3DIntAccumulateModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ScatterSrcStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([10, 8, 6], torch.float32, True),
|
||||
([2, 4, 3], torch.int64, True),
|
||||
([5, 8, 6], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, src):
|
||||
return torch.ops.aten.scatter(input, 0, index, src)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ScatterSrcStaticModule())
|
||||
def ScatterSrcStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
|
||||
tu.rand(5, 8, 6))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ScatterSrcModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, src):
|
||||
return torch.ops.aten.scatter(input, 1, index, src)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ScatterSrcModule())
|
||||
def ScatterSrcModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
|
||||
tu.rand(3, 4, 3))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ScatterValueFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
([], torch.float64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.scatter(input, 2, index, float(value))
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ScatterValueFloatModule())
|
||||
def ScatterValueFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
|
||||
tu.rand().double())
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ScatterValueIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.scatter(input, 0, index, int(value))
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ScatterValueIntModule())
|
||||
def ScatterValueIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
|
||||
tu.randint(high=10))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ScatterReduceFloatModule(torch.nn.Module):
|
||||
include_self: bool
|
||||
reduce_type: str
|
||||
|
|
Loading…
Reference in New Issue