mirror of https://github.com/llvm/torch-mlir
[TM_TENSOR] Add `aten.scatter.[src|value]` op
This commit adds support of `aten.scatter.src` and `aten.scatter.value` ops. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/2183/head snapshot-20230529.853
parent
b9d29dc055
commit
552887783a
|
@ -263,6 +263,11 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
# ERROR: Exception: Unsupported: return type List[Tensor] in schema for aten.unbind.int
|
# ERROR: Exception: Unsupported: return type List[Tensor] in schema for aten.unbind.int
|
||||||
"UnbindIntListUnpack_Module_basic",
|
"UnbindIntListUnpack_Module_basic",
|
||||||
"UnbindIntGetItem_Module_basic",
|
"UnbindIntGetItem_Module_basic",
|
||||||
|
|
||||||
|
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
|
||||||
|
"ScatterValueFloatModule_basic",
|
||||||
|
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
|
||||||
|
"ScatterValueIntModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCHDYNAMO_CRASHING_SET = {
|
TORCHDYNAMO_CRASHING_SET = {
|
||||||
|
@ -1247,4 +1252,6 @@ LTC_XFAIL_SET = {
|
||||||
"ChunkListUnpackUneven_Module_basic",
|
"ChunkListUnpackUneven_Module_basic",
|
||||||
"ChunkListUnpackDynamic_Module_basic",
|
"ChunkListUnpackDynamic_Module_basic",
|
||||||
"ChunkListUnpackUnevenDynamic_Module_basic",
|
"ChunkListUnpackUnevenDynamic_Module_basic",
|
||||||
|
"ScatterValueFloatModule_basic",
|
||||||
|
"ScatterValueIntModule_basic",
|
||||||
}
|
}
|
||||||
|
|
|
@ -5073,6 +5073,108 @@ def Torch_Aten_LogSoftmaxOp : Torch_Op<"aten._log_softmax", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
Torch_IntType:$dim,
|
||||||
|
AnyTorchTensorType:$index,
|
||||||
|
AnyTorchTensorType:$src
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenScatterSrcOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||||
|
}
|
||||||
|
void AtenScatterSrcOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 4, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenScatter_SrcOp : Torch_Op<"aten.scatter_.src", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::scatter_.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
Torch_IntType:$dim,
|
||||||
|
AnyTorchTensorType:$index,
|
||||||
|
AnyTorchTensorType:$src
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenScatter_SrcOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||||
|
}
|
||||||
|
void AtenScatter_SrcOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 4, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenScatterValueOp : Torch_Op<"aten.scatter.value", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
Torch_IntType:$dim,
|
||||||
|
AnyTorchTensorType:$index,
|
||||||
|
AnyTorchScalarType:$value
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenScatterValueOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||||
|
}
|
||||||
|
void AtenScatterValueOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 4, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenScatter_ValueOp : Torch_Op<"aten.scatter_.value", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::scatter_.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
Torch_IntType:$dim,
|
||||||
|
AnyTorchTensorType:$index,
|
||||||
|
AnyTorchScalarType:$value
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenScatter_ValueOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||||
|
}
|
||||||
|
void AtenScatter_ValueOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 4, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
|
def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
@ -8899,58 +9001,6 @@ def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [
|
|
||||||
AllowsTypeRefinement,
|
|
||||||
HasValueSemantics,
|
|
||||||
ReadOnly
|
|
||||||
]> {
|
|
||||||
let summary = "Generated op for `aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
|
|
||||||
let arguments = (ins
|
|
||||||
AnyTorchTensorType:$self,
|
|
||||||
Torch_IntType:$dim,
|
|
||||||
AnyTorchTensorType:$index,
|
|
||||||
AnyTorchTensorType:$src
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
AnyTorchTensorType:$result
|
|
||||||
);
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
let extraClassDefinition = [{
|
|
||||||
ParseResult AtenScatterSrcOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
||||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
|
||||||
}
|
|
||||||
void AtenScatterSrcOp::print(OpAsmPrinter &printer) {
|
|
||||||
printDefaultTorchOp(printer, *this, 4, 1);
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_AtenScatterValueOp : Torch_Op<"aten.scatter.value", [
|
|
||||||
AllowsTypeRefinement,
|
|
||||||
HasValueSemantics,
|
|
||||||
ReadOnly
|
|
||||||
]> {
|
|
||||||
let summary = "Generated op for `aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`";
|
|
||||||
let arguments = (ins
|
|
||||||
AnyTorchTensorType:$self,
|
|
||||||
Torch_IntType:$dim,
|
|
||||||
AnyTorchTensorType:$index,
|
|
||||||
AnyTorchScalarType:$value
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
AnyTorchTensorType:$result
|
|
||||||
);
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
let extraClassDefinition = [{
|
|
||||||
ParseResult AtenScatterValueOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
||||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
|
||||||
}
|
|
||||||
void AtenScatterValueOp::print(OpAsmPrinter &printer) {
|
|
||||||
printDefaultTorchOp(printer, *this, 4, 1);
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
|
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -299,6 +299,55 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
|
||||||
return SmallVector<Value>(sortOp.getResults());
|
return SmallVector<Value>(sortOp.getResults());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenScatterSrcOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
return failure();
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
TypeConverter *typeConverter = getTypeConverter();
|
||||||
|
Value self = adaptor.getSelf();
|
||||||
|
Value index = adaptor.getIndex();
|
||||||
|
Value src = adaptor.getSrc();
|
||||||
|
|
||||||
|
RankedTensorType selfType = self.getType().cast<RankedTensorType>();
|
||||||
|
RankedTensorType indexType = index.getType().cast<RankedTensorType>();
|
||||||
|
RankedTensorType srcType = src.getType().cast<RankedTensorType>();
|
||||||
|
if (selfType.getRank() != indexType.getRank() ||
|
||||||
|
indexType.getRank() != srcType.getRank())
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"'self', 'index' and 'src' should all"
|
||||||
|
"have the same number of dimensions.");
|
||||||
|
|
||||||
|
int64_t dim;
|
||||||
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"unimplemented: dim is not constant");
|
||||||
|
|
||||||
|
// Get the inputs reformatted for the TMScatterOp
|
||||||
|
auto [indices, updates] =
|
||||||
|
convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(rewriter, index,
|
||||||
|
src, dim);
|
||||||
|
Value scatterOp = createTMTensorScatterOp(
|
||||||
|
rewriter, loc, updates, indices, self,
|
||||||
|
/*uniqueIndices=*/false,
|
||||||
|
[&](OpBuilder &b, Location loc, Value updatesElement,
|
||||||
|
Value inputElement) {
|
||||||
|
b.create<TMTensor::YieldOp>(loc, updatesElement);
|
||||||
|
});
|
||||||
|
|
||||||
|
auto resultType = typeConverter->convertType(op->getResult(0).getType())
|
||||||
|
.cast<RankedTensorType>();
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// aten::bincount op counts the frequency of each value in a 1-d input tensor of
|
// aten::bincount op counts the frequency of each value in a 1-d input tensor of
|
||||||
// non-negative ints.
|
// non-negative ints.
|
||||||
|
@ -1606,6 +1655,9 @@ public:
|
||||||
patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
|
patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
|
||||||
context);
|
context);
|
||||||
|
|
||||||
|
target.addIllegalOp<AtenScatterSrcOp>();
|
||||||
|
patterns.add<ConvertAtenScatterSrcOp>(typeConverter, context);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns))))
|
std::move(patterns))))
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
|
|
@ -7403,6 +7403,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.scatter_reduce.two\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.scatter_reduce.two\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
||||||
" return %arg0 : !torch.list<int>\n"
|
" return %arg0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.scatter.src\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
|
" return %arg0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.scatter.value\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.float) -> !torch.list<int> {\n"
|
||||||
|
" return %arg0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -8371,6 +8377,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
" return %0#1 : !torch.int\n"
|
" return %0#1 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.scatter.src\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" return %0#1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.scatter.value\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" return %0#1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.silu\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.silu\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
" return %0#1 : !torch.int\n"
|
" return %0#1 : !torch.int\n"
|
||||||
|
|
|
@ -4399,6 +4399,50 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Decompose `aten.scatter.value` op into `aten.scatter.src` op.
|
||||||
|
class DecomposeAtenScatterValueOp
|
||||||
|
: public OpRewritePattern<AtenScatterValueOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenScatterValueOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
MLIRContext *context = op.getContext();
|
||||||
|
Value self = op.getSelf();
|
||||||
|
Value index = op.getIndex();
|
||||||
|
std::optional<unsigned> maybeIndexRank = getTensorRank(index);
|
||||||
|
if (!maybeIndexRank) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "expected index tensor to have a rank");
|
||||||
|
}
|
||||||
|
unsigned indexRank = *maybeIndexRank;
|
||||||
|
SmallVector<Value> sizes;
|
||||||
|
for (int64_t i = 0; i < indexRank; ++i) {
|
||||||
|
Value dim =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
|
||||||
|
sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, index, /*dim=*/dim));
|
||||||
|
}
|
||||||
|
Value sizeList = rewriter.create<PrimListConstructOp>(
|
||||||
|
loc, ListType::get(IntType::get(context)), sizes);
|
||||||
|
|
||||||
|
auto selfType = self.getType().cast<BaseTensorType>();
|
||||||
|
auto indexType = index.getType().cast<BaseTensorType>();
|
||||||
|
BaseTensorType srcType =
|
||||||
|
selfType
|
||||||
|
.getWithSizesAndDtype(indexType.getOptionalSizes(),
|
||||||
|
selfType.getOptionalDtype())
|
||||||
|
.cast<BaseTensorType>();
|
||||||
|
Value src =
|
||||||
|
createInitTensor(rewriter, loc, srcType, op.getValue(), sizeList);
|
||||||
|
rewriter.replaceOpWithNewOp<AtenScatterSrcOp>(op, op.getType(), self,
|
||||||
|
op.getDim(), index, src);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeComplexOpsPass
|
class DecomposeComplexOpsPass
|
||||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||||
|
@ -4563,6 +4607,7 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
|
||||||
|
|
||||||
GreedyRewriteConfig config;
|
GreedyRewriteConfig config;
|
||||||
config.useTopDownTraversal = true;
|
config.useTopDownTraversal = true;
|
||||||
|
|
|
@ -479,6 +479,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<AtenCrossEntropyLossOp>();
|
target.addIllegalOp<AtenCrossEntropyLossOp>();
|
||||||
target.addIllegalOp<AtenVarMeanDimOp>();
|
target.addIllegalOp<AtenVarMeanDimOp>();
|
||||||
target.addIllegalOp<AtenTopkOp>();
|
target.addIllegalOp<AtenTopkOp>();
|
||||||
|
target.addIllegalOp<AtenScatterValueOp>();
|
||||||
for (auto &opName : backendLegalOpsSet) {
|
for (auto &opName : backendLegalOpsSet) {
|
||||||
target.addLegalOp(
|
target.addLegalOp(
|
||||||
OperationName(kTorchOpPrefix + opName.first().str(), context));
|
OperationName(kTorchOpPrefix + opName.first().str(), context));
|
||||||
|
|
|
@ -886,6 +886,12 @@ def aten〇select_scatter〡shape(self: List[int], src: List[int], dim: int, ind
|
||||||
def aten〇scatter_reduce〇two〡shape(self: List[int], dim: int, index: List[int], src: List[int], reduce: str, include_self: bool = True) -> List[int]:
|
def aten〇scatter_reduce〇two〡shape(self: List[int], dim: int, index: List[int], src: List[int], reduce: str, include_self: bool = True) -> List[int]:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def aten〇scatter〇src〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]:
|
||||||
|
return self
|
||||||
|
|
||||||
|
def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int], value: float) -> List[int]:
|
||||||
|
return self
|
||||||
|
|
||||||
def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]:
|
def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.index_select(self, dim, index)
|
return upstream_shape_functions.index_select(self, dim, index)
|
||||||
|
|
||||||
|
@ -1716,6 +1722,18 @@ def aten〇select_scatter〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dty
|
||||||
self_rank, self_dtype = self_rank_dtype
|
self_rank, self_dtype = self_rank_dtype
|
||||||
return self_dtype
|
return self_dtype
|
||||||
|
|
||||||
|
@check_dtype_function(
|
||||||
|
[Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
|
||||||
|
def aten〇scatter〇src〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int:
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
return self_dtype
|
||||||
|
|
||||||
|
@check_dtype_function(
|
||||||
|
[Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), 1.0) for dtype in _SORTED_TORCH_TYPES])
|
||||||
|
def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int:
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
return self_dtype
|
||||||
|
|
||||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||||
def aten〇silu〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
def aten〇silu〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||||
self_rank, self_dtype = self_rank_dtype
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
|
|
@ -410,6 +410,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit(
|
emit(
|
||||||
"aten::_log_softmax : (Tensor, int, bool) -> (Tensor)"
|
"aten::_log_softmax : (Tensor, int, bool) -> (Tensor)"
|
||||||
)
|
)
|
||||||
|
emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
|
||||||
|
emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
|
||||||
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
||||||
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
|
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
|
||||||
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
|
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
|
||||||
|
@ -559,8 +561,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
|
emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
|
||||||
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
|
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
|
||||||
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
|
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
|
||||||
emit("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
|
|
||||||
emit("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
|
|
||||||
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
|
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||||
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
|
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
|
||||||
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
|
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
|
||||||
|
|
|
@ -821,6 +821,102 @@ def IndexPutHackedTwin3DIntAccumulateModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ScatterSrcStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([10, 8, 6], torch.float32, True),
|
||||||
|
([2, 4, 3], torch.int64, True),
|
||||||
|
([5, 8, 6], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, input, index, src):
|
||||||
|
return torch.ops.aten.scatter(input, 0, index, src)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ScatterSrcStaticModule())
|
||||||
|
def ScatterSrcStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
|
||||||
|
tu.rand(5, 8, 6))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ScatterSrcModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.int64, True),
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, input, index, src):
|
||||||
|
return torch.ops.aten.scatter(input, 1, index, src)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ScatterSrcModule())
|
||||||
|
def ScatterSrcModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
|
||||||
|
tu.rand(3, 4, 3))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ScatterValueFloatModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.int64, True),
|
||||||
|
([], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, input, index, value):
|
||||||
|
return torch.ops.aten.scatter(input, 2, index, float(value))
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ScatterValueFloatModule())
|
||||||
|
def ScatterValueFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
|
||||||
|
tu.rand().double())
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ScatterValueIntModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.int64, True),
|
||||||
|
([], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, input, index, value):
|
||||||
|
return torch.ops.aten.scatter(input, 0, index, int(value))
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ScatterValueIntModule())
|
||||||
|
def ScatterValueIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
|
||||||
|
tu.randint(high=10))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ScatterReduceFloatModule(torch.nn.Module):
|
class ScatterReduceFloatModule(torch.nn.Module):
|
||||||
include_self: bool
|
include_self: bool
|
||||||
reduce_type: str
|
reduce_type: str
|
||||||
|
|
Loading…
Reference in New Issue