diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c5b491197..f169993a1 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -14700,6 +14700,29 @@ def Torch_AtenHstackOp : Torch_Op<"aten.hstack", [ }]; } +def Torch_AtenColumnStackOp : Torch_Op<"aten.column_stack", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::column_stack : (Tensor[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchListOfTensorType:$tensors + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenColumnStackOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenColumnStackOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [ AllowsTypeRefinement ]> { diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 7a0a24a28..560b6a821 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10886,6 +10886,37 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %5 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.column_stack\"(%arg0: !torch.list>) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list>\n" +" %1 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" torch.prim.Loop %1, %true, init() {\n" +" ^bb0(%arg1: !torch.int):\n" +" %3 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list>, !torch.int -> !torch.list\n" +" %4 = torch.aten.len.t %3 : !torch.list -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.list) {\n" +" %8 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %8 : !torch.list\n" +" } else {\n" +" %8 = torch.aten.len.t %3 : !torch.list -> !torch.int\n" +" %9 = torch.aten.eq.int %8, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %9 -> () {\n" +" %10 = torch.aten.append.t %3, %int1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %3 : !torch.list\n" +" }\n" +" %7 = torch.aten.append.t %0, %6 : !torch.list>, !torch.list -> !torch.list>\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %2 = call @__torch__.torch.jit._shape_functions.cat(%0, %int1) : (!torch.list>, !torch.int) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" @@ -15621,6 +15652,33 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" " return %5 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.column_stack\"(%arg0: !torch.list>) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list>\n" +" %1 = torch.prim.ListConstruct : () -> !torch.list\n" +" %2 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" torch.prim.Loop %4, %true, init() {\n" +" ^bb0(%arg1: !torch.int):\n" +" %6 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list>, !torch.int -> !torch.tuple\n" +" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple -> !torch.int, !torch.int\n" +" %8 = torch.aten.append.t %0, %7#0 : !torch.list>, !torch.int -> !torch.list>\n" +" %9 = torch.aten.append.t %1, %7#1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list>, %arg2: !torch.optional>) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %none = torch.constant.none\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 9db8a6949..445a354d4 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -4192,6 +4192,68 @@ public: }; } // namespace +// Decompose `aten.column_stack` into `aten.reshape` and `aten.cat`. +// https://github.com/pytorch/pytorch/blob/207564bab1c4fe42750931765734ee604032fb69/torch/_refs/__init__.py#L2822 +namespace { +class DecomposeAtenColumnStackOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenColumnStackOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + SmallVector tensors; + if (!getListConstructElements(op.getTensors(), tensors)) + return rewriter.notifyMatchFailure( + op, "unimplemented: the tensor list is not from list construct"); + + for (auto tensor : tensors) { + auto tTy = dyn_cast(tensor.getType()); + if (!tTy || !tTy.hasSizes()) + return rewriter.notifyMatchFailure( + op, "unimplemented: one tensor does not have known sizes"); + } + + SmallVector tensors2d; + for (auto tensor : tensors) { + auto tTy = dyn_cast(tensor.getType()); + SmallVector tSizes(tTy.getSizes()); + if (tSizes.size() <= 1) { + if (tSizes.size() == 0) { + tSizes.push_back(1); + } + tSizes.push_back(1); + auto newTy = tTy.getWithSizesAndDtype(tSizes, tTy.getDtype()); + SmallVector newShapeList; + for (auto tSize : tSizes) { + newShapeList.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(tSize))); + } + auto newShape = rewriter.create( + loc, Torch::ListType::get(rewriter.getType()), + newShapeList); + Value tensor2d = + rewriter.create(loc, newTy, tensor, newShape); + tensors2d.push_back(tensor2d); + } else { + tensors2d.push_back(tensor); + } + } + + auto elemType = cast(tensors2d[0].getType()) + .getWithSizesAndDtype(std::nullopt, nullptr); + Value newTensors = rewriter.create( + loc, Torch::ListType::get(elemType), tensors2d); + + rewriter.replaceOpWithNewOp( + op, op.getType(), newTensors, + rewriter.create(loc, rewriter.getI64IntegerAttr(1))); + + return success(); + } +}; +} // namespace + // Decompose aten.roll into aten.slice and aten.cat ops. // https://pytorch.org/docs/stable/generated/torch.roll.html namespace { @@ -10554,6 +10616,7 @@ public: DecomposeConstantTensorAllocLikeOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 4bca74470..4dd855be4 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -382,6 +382,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 47a095683..18adad513 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2866,6 +2866,9 @@ ONNX_XFAIL_SET = { "CollapsePartialDynamicModule_basic", "CollapseRank1DynamicModule_basic", "CollapseStaticModule_basic", + "ColumnStackBasicIntModule_basic", + "ColumnStack1dModule_basic", + "ColumnStack0dModule_basic", "ConstantBoolParameterModule_basic", "ContainsIntList_False", "ContainsIntList_True", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index e78b3d49d..8a9e7755e 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -2279,6 +2279,20 @@ def aten〇hstack〡shape(tensors: List[List[int]]) -> List[int]: return upstream_shape_functions.cat(tensors_atleast1d, dim=1) +@check_shape_function([ + Invocation([LongTensorOfShape(2, 4, 3), LongTensorOfShape(2, 5, 3)]), # Basic case. +]) +def aten〇column_stack〡shape(tensors: List[List[int]]) -> List[int]: + tensors2d: List[List[int]] = [] + for tensor in tensors: + if len(tensor) == 0: + tensor = [1, 1] + elif len(tensor) == 1: + tensor.append(1) + tensors2d.append(tensor) + + return upstream_shape_functions.cat(tensors2d, dim=1) + def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]: return self @@ -5560,6 +5574,23 @@ def aten〇hstack〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> int: return promote_dtypes(ranks, dtypes) +@check_dtype_function( + [Invocation([NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int32), NonZeroDTensorWithDtype(torch.int64)]), + Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]), + Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]), + Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32), + NonZeroDTensorWithDtype(torch.complex64)])]) +def aten〇column_stack〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> int: + ranks: List[Optional[int]] = [] + dtypes: List[int] = [] + assert len(tensors_rank_dtype) != 0 + for tensor_rank_dtype in tensors_rank_dtype: + tensor_rank, tensor_dtype = tensor_rank_dtype + ranks.append(tensor_rank) + dtypes.append(tensor_dtype) + + return promote_dtypes(ranks, dtypes) + @check_dtype_function( [Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32), TensorOfShape(1, dtype=torch.int32)]),]) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 371b73347..0913b2c67 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1053,6 +1053,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): ) emit("aten::stack : (Tensor[], int) -> (Tensor)") emit("aten::hstack : (Tensor[]) -> (Tensor)") + emit("aten::column_stack : (Tensor[]) -> (Tensor)") emit("aten::append.t : (t[], t) -> (t[])") emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True) emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 5aa22ce3b..94f1538db 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -1409,6 +1409,83 @@ def HstackBasicComplexModule_basic(module, tu: TestUtils): # ============================================================================== +class ColumnStackBasicIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 4], torch.bool, True), + ([2, 3, 4], torch.int32, True), + ([2, 3, 4], torch.int64, True), + ] + ) + def forward(self, x, y, z): + return torch.ops.aten.column_stack([x, y, z]) + + +@register_test_case(module_factory=lambda: ColumnStackBasicIntModule()) +def ColumnStackBasicIntModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(2, 3, 4, low=0, high=2).bool(), + tu.randint(2, 3, 4, low=0, high=100).int(), + tu.randint(2, 3, 4, low=0, high=100).long(), + ) + + +# ============================================================================== + + +class ColumnStack1dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([4], torch.float32, True), + ([4], torch.float32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.column_stack([x, y]) + + +@register_test_case(module_factory=lambda: ColumnStack1dModule()) +def ColumnStack1dModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4), tu.rand(4)) + + +# ============================================================================== + + +class ColumnStack0dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([], torch.float32, True), + ([], torch.float32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.column_stack([x, y]) + + +@register_test_case(module_factory=lambda: ColumnStack0dModule()) +def ColumnStack0dModule_basic(module, tu: TestUtils): + module.forward(torch.tensor(4.0), torch.tensor(1.0)) + + +# ============================================================================== + + class GatherModule(torch.nn.Module): def __init__(self): super().__init__()