mirror of https://github.com/llvm/torch-mlir
[Torch] support aten.column_stack (#3867)
parent
95f77817b9
commit
896f66c688
|
@ -14700,6 +14700,29 @@ def Torch_AtenHstackOp : Torch_Op<"aten.hstack", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenColumnStackOp : Torch_Op<"aten.column_stack", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::column_stack : (Tensor[]) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchListOfTensorType:$tensors
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenColumnStackOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||||
|
}
|
||||||
|
void AtenColumnStackOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [
|
def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [
|
||||||
AllowsTypeRefinement
|
AllowsTypeRefinement
|
||||||
]> {
|
]> {
|
||||||
|
|
|
@ -10886,6 +10886,37 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" return %5 : !torch.list<int>\n"
|
" return %5 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.column_stack\"(%arg0: !torch.list<list<int>>) -> !torch.list<int> {\n"
|
||||||
|
" %true = torch.constant.bool true\n"
|
||||||
|
" %int0 = torch.constant.int 0\n"
|
||||||
|
" %int1 = torch.constant.int 1\n"
|
||||||
|
" %0 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
|
||||||
|
" %1 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
|
||||||
|
" torch.prim.Loop %1, %true, init() {\n"
|
||||||
|
" ^bb0(%arg1: !torch.int):\n"
|
||||||
|
" %3 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
|
||||||
|
" %4 = torch.aten.len.t %3 : !torch.list<int> -> !torch.int\n"
|
||||||
|
" %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" %6 = torch.prim.If %5 -> (!torch.list<int>) {\n"
|
||||||
|
" %8 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" torch.prim.If.yield %8 : !torch.list<int>\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %8 = torch.aten.len.t %3 : !torch.list<int> -> !torch.int\n"
|
||||||
|
" %9 = torch.aten.eq.int %8, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" torch.prim.If %9 -> () {\n"
|
||||||
|
" %10 = torch.aten.append.t %3, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" torch.prim.If.yield %3 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
|
" %7 = torch.aten.append.t %0, %6 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>\n"
|
||||||
|
" torch.prim.Loop.condition %true, iter()\n"
|
||||||
|
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||||
|
" %2 = call @__torch__.torch.jit._shape_functions.cat(%0, %int1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" return %2 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
|
||||||
" return %arg0 : !torch.list<int>\n"
|
" return %arg0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
@ -15621,6 +15652,33 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
" return %5 : !torch.int\n"
|
" return %5 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.column_stack\"(%arg0: !torch.list<tuple<int, int>>) -> !torch.int {\n"
|
||||||
|
" %true = torch.constant.bool true\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
" %int0 = torch.constant.int 0\n"
|
||||||
|
" %0 = torch.prim.ListConstruct : () -> !torch.list<optional<int>>\n"
|
||||||
|
" %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||||
|
" %2 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
|
||||||
|
" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" torch.prim.If %3 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %4 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
|
||||||
|
" torch.prim.Loop %4, %true, init() {\n"
|
||||||
|
" ^bb0(%arg1: !torch.int):\n"
|
||||||
|
" %6 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
|
||||||
|
" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %8 = torch.aten.append.t %0, %7#0 : !torch.list<optional<int>>, !torch.int -> !torch.list<optional<int>>\n"
|
||||||
|
" %9 = torch.aten.append.t %1, %7#1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||||
|
" torch.prim.Loop.condition %true, iter()\n"
|
||||||
|
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||||
|
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
|
" return %5 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.optional<list<int>>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.optional<list<int>>) -> !torch.int {\n"
|
||||||
" %true = torch.constant.bool true\n"
|
" %true = torch.constant.bool true\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
|
|
|
@ -4192,6 +4192,68 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Decompose `aten.column_stack` into `aten.reshape` and `aten.cat`.
|
||||||
|
// https://github.com/pytorch/pytorch/blob/207564bab1c4fe42750931765734ee604032fb69/torch/_refs/__init__.py#L2822
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenColumnStackOp : public OpRewritePattern<AtenColumnStackOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenColumnStackOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
|
SmallVector<Value> tensors;
|
||||||
|
if (!getListConstructElements(op.getTensors(), tensors))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: the tensor list is not from list construct");
|
||||||
|
|
||||||
|
for (auto tensor : tensors) {
|
||||||
|
auto tTy = dyn_cast<BaseTensorType>(tensor.getType());
|
||||||
|
if (!tTy || !tTy.hasSizes())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: one tensor does not have known sizes");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> tensors2d;
|
||||||
|
for (auto tensor : tensors) {
|
||||||
|
auto tTy = dyn_cast<BaseTensorType>(tensor.getType());
|
||||||
|
SmallVector<int64_t> tSizes(tTy.getSizes());
|
||||||
|
if (tSizes.size() <= 1) {
|
||||||
|
if (tSizes.size() == 0) {
|
||||||
|
tSizes.push_back(1);
|
||||||
|
}
|
||||||
|
tSizes.push_back(1);
|
||||||
|
auto newTy = tTy.getWithSizesAndDtype(tSizes, tTy.getDtype());
|
||||||
|
SmallVector<Value> newShapeList;
|
||||||
|
for (auto tSize : tSizes) {
|
||||||
|
newShapeList.push_back(rewriter.create<ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(tSize)));
|
||||||
|
}
|
||||||
|
auto newShape = rewriter.create<PrimListConstructOp>(
|
||||||
|
loc, Torch::ListType::get(rewriter.getType<IntType>()),
|
||||||
|
newShapeList);
|
||||||
|
Value tensor2d =
|
||||||
|
rewriter.create<AtenReshapeOp>(loc, newTy, tensor, newShape);
|
||||||
|
tensors2d.push_back(tensor2d);
|
||||||
|
} else {
|
||||||
|
tensors2d.push_back(tensor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto elemType = cast<BaseTensorType>(tensors2d[0].getType())
|
||||||
|
.getWithSizesAndDtype(std::nullopt, nullptr);
|
||||||
|
Value newTensors = rewriter.create<PrimListConstructOp>(
|
||||||
|
loc, Torch::ListType::get(elemType), tensors2d);
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<AtenCatOp>(
|
||||||
|
op, op.getType(), newTensors,
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1)));
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Decompose aten.roll into aten.slice and aten.cat ops.
|
// Decompose aten.roll into aten.slice and aten.cat ops.
|
||||||
// https://pytorch.org/docs/stable/generated/torch.roll.html
|
// https://pytorch.org/docs/stable/generated/torch.roll.html
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -10554,6 +10616,7 @@ public:
|
||||||
DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
|
DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHstackOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenHstackOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenColumnStackOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatInterleaveSelfIntOp>(
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatInterleaveSelfIntOp>(
|
||||||
|
|
|
@ -382,6 +382,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<AtenZerosLikeOp>();
|
target.addIllegalOp<AtenZerosLikeOp>();
|
||||||
target.addIllegalOp<AtenStackOp>();
|
target.addIllegalOp<AtenStackOp>();
|
||||||
target.addIllegalOp<AtenHstackOp>();
|
target.addIllegalOp<AtenHstackOp>();
|
||||||
|
target.addIllegalOp<AtenColumnStackOp>();
|
||||||
target.addIllegalOp<AtenRollOp>();
|
target.addIllegalOp<AtenRollOp>();
|
||||||
target.addIllegalOp<AtenRepeatOp>();
|
target.addIllegalOp<AtenRepeatOp>();
|
||||||
target.addIllegalOp<AtenRepeatInterleaveSelfIntOp>();
|
target.addIllegalOp<AtenRepeatInterleaveSelfIntOp>();
|
||||||
|
|
|
@ -2866,6 +2866,9 @@ ONNX_XFAIL_SET = {
|
||||||
"CollapsePartialDynamicModule_basic",
|
"CollapsePartialDynamicModule_basic",
|
||||||
"CollapseRank1DynamicModule_basic",
|
"CollapseRank1DynamicModule_basic",
|
||||||
"CollapseStaticModule_basic",
|
"CollapseStaticModule_basic",
|
||||||
|
"ColumnStackBasicIntModule_basic",
|
||||||
|
"ColumnStack1dModule_basic",
|
||||||
|
"ColumnStack0dModule_basic",
|
||||||
"ConstantBoolParameterModule_basic",
|
"ConstantBoolParameterModule_basic",
|
||||||
"ContainsIntList_False",
|
"ContainsIntList_False",
|
||||||
"ContainsIntList_True",
|
"ContainsIntList_True",
|
||||||
|
|
|
@ -2279,6 +2279,20 @@ def aten〇hstack〡shape(tensors: List[List[int]]) -> List[int]:
|
||||||
|
|
||||||
return upstream_shape_functions.cat(tensors_atleast1d, dim=1)
|
return upstream_shape_functions.cat(tensors_atleast1d, dim=1)
|
||||||
|
|
||||||
|
@check_shape_function([
|
||||||
|
Invocation([LongTensorOfShape(2, 4, 3), LongTensorOfShape(2, 5, 3)]), # Basic case.
|
||||||
|
])
|
||||||
|
def aten〇column_stack〡shape(tensors: List[List[int]]) -> List[int]:
|
||||||
|
tensors2d: List[List[int]] = []
|
||||||
|
for tensor in tensors:
|
||||||
|
if len(tensor) == 0:
|
||||||
|
tensor = [1, 1]
|
||||||
|
elif len(tensor) == 1:
|
||||||
|
tensor.append(1)
|
||||||
|
tensors2d.append(tensor)
|
||||||
|
|
||||||
|
return upstream_shape_functions.cat(tensors2d, dim=1)
|
||||||
|
|
||||||
def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]:
|
def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -5560,6 +5574,23 @@ def aten〇hstack〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> int:
|
||||||
|
|
||||||
return promote_dtypes(ranks, dtypes)
|
return promote_dtypes(ranks, dtypes)
|
||||||
|
|
||||||
|
@check_dtype_function(
|
||||||
|
[Invocation([NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int32), NonZeroDTensorWithDtype(torch.int64)]),
|
||||||
|
Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]),
|
||||||
|
Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]),
|
||||||
|
Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32),
|
||||||
|
NonZeroDTensorWithDtype(torch.complex64)])])
|
||||||
|
def aten〇column_stack〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> int:
|
||||||
|
ranks: List[Optional[int]] = []
|
||||||
|
dtypes: List[int] = []
|
||||||
|
assert len(tensors_rank_dtype) != 0
|
||||||
|
for tensor_rank_dtype in tensors_rank_dtype:
|
||||||
|
tensor_rank, tensor_dtype = tensor_rank_dtype
|
||||||
|
ranks.append(tensor_rank)
|
||||||
|
dtypes.append(tensor_dtype)
|
||||||
|
|
||||||
|
return promote_dtypes(ranks, dtypes)
|
||||||
|
|
||||||
@check_dtype_function(
|
@check_dtype_function(
|
||||||
[Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32),
|
[Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32),
|
||||||
TensorOfShape(1, dtype=torch.int32)]),])
|
TensorOfShape(1, dtype=torch.int32)]),])
|
||||||
|
|
|
@ -1053,6 +1053,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
)
|
)
|
||||||
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
||||||
emit("aten::hstack : (Tensor[]) -> (Tensor)")
|
emit("aten::hstack : (Tensor[]) -> (Tensor)")
|
||||||
|
emit("aten::column_stack : (Tensor[]) -> (Tensor)")
|
||||||
emit("aten::append.t : (t[], t) -> (t[])")
|
emit("aten::append.t : (t[], t) -> (t[])")
|
||||||
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
|
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
|
||||||
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
|
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
|
||||||
|
|
|
@ -1409,6 +1409,83 @@ def HstackBasicComplexModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnStackBasicIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([2, 3, 4], torch.bool, True),
|
||||||
|
([2, 3, 4], torch.int32, True),
|
||||||
|
([2, 3, 4], torch.int64, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x, y, z):
|
||||||
|
return torch.ops.aten.column_stack([x, y, z])
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ColumnStackBasicIntModule())
|
||||||
|
def ColumnStackBasicIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
tu.randint(2, 3, 4, low=0, high=2).bool(),
|
||||||
|
tu.randint(2, 3, 4, low=0, high=100).int(),
|
||||||
|
tu.randint(2, 3, 4, low=0, high=100).long(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnStack1dModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([4], torch.float32, True),
|
||||||
|
([4], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.ops.aten.column_stack([x, y])
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ColumnStack1dModule())
|
||||||
|
def ColumnStack1dModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4), tu.rand(4))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnStack0dModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([], torch.float32, True),
|
||||||
|
([], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.ops.aten.column_stack([x, y])
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ColumnStack0dModule())
|
||||||
|
def ColumnStack0dModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.tensor(4.0), torch.tensor(1.0))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class GatherModule(torch.nn.Module):
|
class GatherModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue