[Torch] Decompose torch.take op into AtenFlattenUsingInts and AtenSelectIndex

Decompose op inside DecomposeComplexOps.cpp

Add tests into slice_like.py
pull/3761/head
Bratislav Filipovic 2024-10-02 13:46:09 +02:00
parent 67732883fa
commit c7ef459f33
7 changed files with 148 additions and 15 deletions

View File

@ -13148,6 +13148,30 @@ def Torch_AtenStftOp : Torch_Op<"aten.stft", [
}]; }];
} }
def Torch_AtenTakeOp : Torch_Op<"aten.take", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::take : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$index
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenTakeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenTakeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -6591,6 +6591,38 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n" " %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
" return %3 : !torch.tuple<int, int, int>\n" " return %3 : !torch.tuple<int, int, int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.take\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.take\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: indexes must be integer types\"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n"
" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
" return %1 : !torch.bool\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list<int> {\n"
" %int4 = torch.constant.int 4\n"
" %int3 = torch.constant.int 3\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %int11 = torch.constant.int 11\n"
" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.linalg_slogdet\"(%arg0: !torch.list<int>) -> !torch.tuple<list<int>, list<int>> {\n" " func.func @\"__torch_mlir_shape_fn.aten.linalg_slogdet\"(%arg0: !torch.list<int>) -> !torch.tuple<list<int>, list<int>> {\n"
" %int-2 = torch.constant.int -2\n" " %int-2 = torch.constant.int -2\n"
" %int-1 = torch.constant.int -1\n" " %int-1 = torch.constant.int -1\n"
@ -11238,21 +11270,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " %3 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %3 : !torch.int\n" " return %3 : !torch.int\n"
" }\n" " }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n"
" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
" return %1 : !torch.bool\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list<int> {\n"
" %int4 = torch.constant.int 4\n"
" %int3 = torch.constant.int 3\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %int11 = torch.constant.int 11\n"
" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"

View File

@ -4113,6 +4113,38 @@ public:
}; };
} // namespace } // namespace
namespace {
class DecomposeAtenTakeOp : public OpRewritePattern<AtenTakeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTakeOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
Value index = op.getIndex();
auto selfTy = cast<BaseTensorType>(self.getType());
auto resType = cast<BaseTensorType>(op.getType());
int64_t selfNumel = getTensorNumel(self).value(); // as selfTy has sizes
auto flattenType = selfTy.getWithSizesAndDtype(
/*optionalSizes=*/{selfNumel}, resType.getDtype());
Value constMinusOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
Value constZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value flattenSelf = rewriter.create<AtenFlattenUsingIntsOp>(
loc, flattenType, self, constZero, constMinusOne);
Value result = rewriter.create<Torch::AtenIndexSelectOp>(
loc, resType, flattenSelf, constZero, index);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
// decompose aten.repeat_interleave.self_int into following ops: // decompose aten.repeat_interleave.self_int into following ops:
// aten.flatten.using_ints, aten.unsqueeze, aten.tile, aten.reshape // aten.flatten.using_ints, aten.unsqueeze, aten.tile, aten.reshape
namespace { namespace {
@ -9660,6 +9692,7 @@ public:
legalOpsSet.clear(); legalOpsSet.clear();
legalOpsSet.insert(legalOps.begin(), legalOps.end()); legalOpsSet.insert(legalOps.begin(), legalOps.end());
addPatternIfTargetOpIsIllegal<DecomposeAtenTakeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>( addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
patterns); patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);

View File

@ -414,6 +414,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenLinalgSlogdetOp>(); target.addIllegalOp<AtenLinalgSlogdetOp>();
target.addIllegalOp<AtenPixelShuffleOp>(); target.addIllegalOp<AtenPixelShuffleOp>();
target.addIllegalOp<AtenTOp>(); target.addIllegalOp<AtenTOp>();
target.addIllegalOp<AtenTakeOp>();
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>(); target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) { target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
std::optional<unsigned> lhsRank = getTensorRank(op.getSelf()); std::optional<unsigned> lhsRank = getTensorRank(op.getSelf());

View File

@ -257,6 +257,16 @@ def aten_linalg_det〡shape(A: List[int]) -> Tuple[List[int], List[int], List
def aten_linalg_det〡dtype(A_rank_dtype: Tuple[int, int]) -> Tuple[int, int, int]: def aten_linalg_det〡dtype(A_rank_dtype: Tuple[int, int]) -> Tuple[int, int, int]:
return (A_rank_dtype[1], A_rank_dtype[1], A_rank_dtype[1]) return (A_rank_dtype[1], A_rank_dtype[1], A_rank_dtype[1])
def atentake〡shape(self: List[int], index: List[int]) -> List[int]:
return index
def atentake〡dtype(self_rank_dtype: Tuple[int, int], index_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
index_rank, index_dtype = index_rank_dtype
assert is_integer_dtype(index_dtype), "indexes must be integer types"
return self_dtype
def atenlinalg_slogdet〡shape(A: List[int]) -> Tuple[List[int], List[int]]: def atenlinalg_slogdet〡shape(A: List[int]) -> Tuple[List[int], List[int]]:
assert len(A) == 2 or len(A) == 3 assert len(A) == 2 or len(A) == 3
assert A[-1] == A[-2] assert A[-1] == A[-2]

View File

@ -964,6 +964,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit( emit(
"aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)" "aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)"
) )
emit("aten::take : (Tensor, Tensor) -> (Tensor)")
# Functionalization ops # Functionalization ops
emit("aten::alias_copy : (Tensor) -> (Tensor)") emit("aten::alias_copy : (Tensor) -> (Tensor)")

View File

@ -1121,3 +1121,50 @@ class TensorSplitSections_ListUnpackModule(torch.nn.Module):
@register_test_case(module_factory=lambda: TensorSplitSections_ListUnpackModule()) @register_test_case(module_factory=lambda: TensorSplitSections_ListUnpackModule())
def TensorSplitSections_ListUnpackModule_basic(module, tu: TestUtils): def TensorSplitSections_ListUnpackModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5)) module.forward(tu.rand(2, 5))
# ==============================================================================
class TakeModule(torch.nn.Module):
@export
@annotate_args([None, [(4, 4), torch.float32, True], [(4,), torch.int64, True]])
def forward(self, input, index):
return torch.take(input, index)
@register_test_case(module_factory=lambda: TakeModule())
def TakeModule_F32(module, tu: TestUtils):
A = tu.rand(4, 4).to(dtype=torch.float32)
index = tu.rand(4, low=0, high=torch.numel(A)).to(dtype=torch.int64)
module.forward(A, index)
class TakeBatchModule(torch.nn.Module):
@export
@annotate_args([None, [(4, 4, 4), torch.float32, True], [(4,), torch.int64, True]])
def forward(self, input, index):
return torch.take(input, index)
@register_test_case(module_factory=lambda: TakeBatchModule())
def TakeModuleBatched_F32(module, tu: TestUtils):
A = tu.rand(4, 4, 4).to(dtype=torch.float32)
index = tu.rand(4, low=0, high=torch.numel(A)).to(dtype=torch.int64)
module.forward(A, index)
class TakeDynamicModule(torch.nn.Module):
@export
@annotate_args(
[None, [(-1, -1, -1), torch.float32, True], [(4,), torch.int64, True]]
)
def forward(self, input, index):
return torch.take(input, index)
@register_test_case(module_factory=lambda: TakeDynamicModule())
def TakeModuleDynamic_F32(module, tu: TestUtils):
A = tu.rand(4, 4, 8).to(dtype=torch.float32)
index = tu.rand(4, low=0, high=torch.numel(A)).to(dtype=torch.int64)
module.forward(A, index)