diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index f697d596e..7591493f8 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -14121,6 +14121,29 @@ def Torch_AtenStackOp : Torch_Op<"aten.stack", [
   }];
 }
 
+def Torch_AtenHstackOp : Torch_Op<"aten.hstack", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::hstack : (Tensor[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchListOfTensorType:$tensors
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenHstackOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenHstackOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [
     AllowsTypeRefinement
   ]> {
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 836428d6e..545fdee26 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10639,6 +10639,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.stack(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.hstack\"(%arg0: !torch.list<list<int>>) -> !torch.list<int> {\n"
+"    %true = torch.constant.bool true\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
+"    %1 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
+"    torch.prim.Loop %1, %true, init() {\n"
+"    ^bb0(%arg1: !torch.int):\n"
+"      %6 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
+"      %7 = func.call @\"__torch_mlir_shape_fn.aten.atleast_1d\"(%6) : (!torch.list<int>) -> !torch.list<int>\n"
+"      %8 = torch.aten.append.t %0, %7 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>\n"
+"      torch.prim.Loop.condition %true, iter()\n"
+"    } : (!torch.int, !torch.bool) -> ()\n"
+"    %2 = torch.aten.__getitem__.t %0, %int0 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
+"    %3 = torch.aten.len.t %2 : !torch.list<int> -> !torch.int\n"
+"    %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"    %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
+"      %6 = func.call @__torch__.torch.jit._shape_functions.cat(%0, %int0) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %6 : !torch.list<int>\n"
+"    } else {\n"
+"      %6 = func.call @__torch__.torch.jit._shape_functions.cat(%0, %int1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %6 : !torch.list<int>\n"
+"    }\n"
+"    return %5 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
@@ -15185,6 +15210,33 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.hstack\"(%arg0: !torch.list<tuple<int, int>>) -> !torch.int {\n"
+"    %true = torch.constant.bool true\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0 = torch.prim.ListConstruct : () -> !torch.list<optional<int>>\n"
+"    %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"    %2 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
+"    %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %4 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
+"    torch.prim.Loop %4, %true, init() {\n"
+"    ^bb0(%arg1: !torch.int):\n"
+"      %6 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
+"      %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"      %8 = torch.aten.append.t %0, %7#0 : !torch.list<optional<int>>, !torch.int -> !torch.list<optional<int>>\n"
+"      %9 = torch.aten.append.t %1, %7#1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"      torch.prim.Loop.condition %true, iter()\n"
+"    } : (!torch.int, !torch.bool) -> ()\n"
+"    %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %5 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.optional<list<int>>) -> !torch.int {\n"
 "    %true = torch.constant.bool true\n"
 "    %none = torch.constant.none\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index f354374fe..b60eda351 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3813,6 +3813,58 @@ public:
 };
 } // namespace
 
+// Decompose `aten.hstack` into `aten.atleast_1d` and `aten.cat`.
+// https://github.com/pytorch/pytorch/blob/207564bab1c4fe42750931765734ee604032fb69/torch/_refs/__init__.py#L3908
+namespace {
+class DecomposeAtenHstackOp : public OpRewritePattern<AtenHstackOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenHstackOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    // Get a SmallVector<Value> from the tensor list operand.
+    SmallVector<Value> tensors;
+    if (!getListConstructElements(op.getTensors(), tensors))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: the tensor list is not from list construct");
+
+    // Execute AtenAtleast1dOp on every tensor inside tensors.
+    SmallVector<Value> atleast1dTensors;
+    for (auto tensor : tensors) {
+      std::optional<unsigned> tensorRank = getTensorRank(tensor);
+
+      // Wrap the tensor in AtenAtleast1dOp unless it is already of rank >= 1.
+      if (*tensorRank < 1) {
+        auto atleast1dTensor =
+            rewriter.create<AtenAtleast1dOp>(loc, tensor.getType(), tensor);
+        atleast1dTensors.push_back(atleast1dTensor);
+      } else {
+        atleast1dTensors.push_back(tensor);
+      }
+    }
+
+    // Make a Value list from the atleast1dTensors variable.
+    auto elemType = cast<BaseTensorType>(atleast1dTensors[0].getType())
+                        .getWithSizesAndDtype(std::nullopt, nullptr);
+    Value atleast1dTensorList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(elemType), atleast1dTensors);
+
+    // Replace hstack with the cat operator.
+    if (getTensorRank(atleast1dTensors[0]) == 1)
+      rewriter.replaceOpWithNewOp<AtenCatOp>(
+          op, op.getType(), atleast1dTensorList,
+          rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0)));
+    else
+      rewriter.replaceOpWithNewOp<AtenCatOp>(
+          op, op.getType(), atleast1dTensorList,
+          rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1)));
+
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.roll into aten.slice and aten.cat ops.
 // https://pytorch.org/docs/stable/generated/torch.roll.html
 namespace {
@@ -9567,6 +9619,7 @@ public:
     addPatternIfTargetOpIsIllegal<
         DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenHstackOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
    addPatternIfTargetOpIsIllegal<
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 669753baa..aa81a68ca 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -380,6 +380,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenZerosLikeOp>();
   target.addIllegalOp<AtenOnesLikeOp>();
   target.addIllegalOp<AtenStackOp>();
+  target.addIllegalOp<AtenHstackOp>();
   target.addIllegalOp<AtenRollOp>();
   target.addIllegalOp<AtenRepeatOp>();
   target.addIllegalOp<AtenRepeatInterleaveSelfIntOp>();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index a510ac186..0430ba9d5 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1213,6 +1213,10 @@ STABLEHLO_PASS_SET = {
     "GridSamplerBasic4_basic",
     "GtFloatIntModule_basic",
     "GtIntModule_basic",
+    "HstackBasicComplexModule_basic",
+    "HstackBasicFloatModule_basic",
+    "HstackBasicIntFloatModule_basic",
+    "HstackBasicIntModule_basic",
     "IndexTensorMultiIndexStaticModule_basic",
     "IndexTensorStaticModule_basic",
     "IntFloatModule_basic",
@@ -2215,6 +2219,11 @@ MAKE_FX_TOSA_PASS_SET = (
     # failed to legalize operation 'torch.aten.rrelu_with_noise'
     "ElementwiseRreluEvalModule_basic",
     "ElementwiseRreluEvalStaticModule_basic",
+    # incompatible return type failure for tosa.concat.
+    "HstackBasicComplexModule_basic",
+    "HstackBasicFloatModule_basic",
+    "HstackBasicIntFloatModule_basic",
+    "HstackBasicIntModule_basic",
     # Shape Related failures
     "PrimListUnpackNumMismatchModule_basic",
     "ReshapeExpandModule_basic",
@@ -2623,6 +2632,10 @@ ONNX_XFAIL_SET = {
     "GtFloatIntModule_basic",
     "GtIntModule_basic",
     "HardtanhBackward_basic",
+    "HstackBasicComplexModule_basic",
+    "HstackBasicFloatModule_basic",
+    "HstackBasicIntFloatModule_basic",
+    "HstackBasicIntModule_basic",
     "IndexPutImpl1DFloatAccumulateModule_basic",
     "IndexPutImpl1DFloatNonAccumulateModule_basic",
     "IndexPutImpl1DIntAccumulateModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 2b80f5ce1..3e1177500 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2159,6 +2159,19 @@ def aten〇atleast_2d〡shape(self: List[int]) -> List[int]:
 
 def aten〇stack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]:
     return upstream_shape_functions.stack(tensors, dim)
+
+@check_shape_function([
+    Invocation([LongTensorOfShape(2, 4, 3), LongTensorOfShape(2, 5, 3)]), # Basic case.
+])
+def aten〇hstack〡shape(tensors: List[List[int]]) -> List[int]:
+
+    tensors_atleast1d = [aten〇atleast_1d〡shape(tensor) for tensor in tensors]
+
+    if len(tensors_atleast1d[0]) == 1:
+        return upstream_shape_functions.cat(tensors_atleast1d, dim=0)
+
+    return upstream_shape_functions.cat(tensors_atleast1d, dim=1)
+
 def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]:
     return self
 
@@ -5325,6 +5338,23 @@ def aten〇atleast_2d〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(
+    [Invocation([NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int32), NonZeroDTensorWithDtype(torch.int64)]),
+     Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]),
+     Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]),
+     Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32),
+                 NonZeroDTensorWithDtype(torch.complex64)])])
+def aten〇hstack〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> int:
+    ranks: List[Optional[int]] = []
+    dtypes: List[int] = []
+    assert len(tensors_rank_dtype) != 0
+    for tensor_rank_dtype in tensors_rank_dtype:
+        tensor_rank, tensor_dtype = tensor_rank_dtype
+        ranks.append(tensor_rank)
+        dtypes.append(tensor_dtype)
+
+    return promote_dtypes(ranks, dtypes)
+
 @check_dtype_function(
     [Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32), TensorOfShape(1, dtype=torch.int32)]),])
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 2e4791dcb..9318ab6f2 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1015,6 +1015,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
         has_folder=True,
     )
     emit("aten::stack : (Tensor[], int) -> (Tensor)")
+    emit("aten::hstack : (Tensor[]) -> (Tensor)")
     emit("aten::append.t : (t[], t) -> (t[])")
     emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
     emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index b33f8e3ee..03e16ab2c 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1308,6 +1308,107 @@ def TensorsStackPromoteDTypeModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class HstackBasicIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3, 4], torch.bool, True),
+            ([2, 3, 4], torch.int32, True),
+            ([2, 3, 4], torch.int64, True),
+        ]
+    )
+    def forward(self, x, y, z):
+        return torch.ops.aten.hstack([x, y, z])
+
+
+@register_test_case(module_factory=lambda: HstackBasicIntModule())
+def HstackBasicIntModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(2, 3, 4, low=0, high=2).bool(),
+        tu.randint(2, 3, 4, low=0, high=100).int(),
+        tu.randint(2, 3, 4, low=0, high=100).long(),
+    )
+
+
+class HstackBasicFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 6, 4], torch.int32, True),
+            ([2, 3, 4], torch.float64, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.hstack([x, y])
+
+
+@register_test_case(module_factory=lambda: HstackBasicFloatModule())
+def HstackBasicFloatModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(2, 6, 4).int(),
+        tu.rand(2, 3, 4).double(),
+    )
+
+
+class HstackBasicIntFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.int32, True),
+            ([-1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.hstack([x, y])
+
+
+@register_test_case(module_factory=lambda: HstackBasicIntFloatModule())
+def HstackBasicIntFloatModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(4, 6, 4, 2, low=1, high=50).int(),
+        tu.rand(4, 3, 4, 2),
+    )
+
+
+class HstackBasicComplexModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.complex64, True),
+            ([-1, -1, -1, -1], torch.complex128, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.hstack([x, y])
+
+
+@register_test_case(module_factory=lambda: HstackBasicComplexModule())
+def HstackBasicComplexModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(4, 6, 4, 2).type(torch.complex64),
+        tu.rand(4, 3, 4, 2).type(torch.complex128),
+    )
+
+
+# ==============================================================================
+
+
 class GatherModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
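
Note (not part of the patch): the decomposition above mirrors the PyTorch reference linked from DecomposeComplexOps.cpp — every input is promoted with `atleast_1d`, then the list is concatenated along dim 0 when the promoted tensors are 1-D and along dim 1 otherwise. A minimal PyTorch sketch of that rule, for readers checking the expected semantics (the helper name `hstack_reference` is illustrative, not from the patch):

import torch

def hstack_reference(tensors):
    # Promote 0-D inputs to 1-D, mirroring the AtenAtleast1dOp step.
    tensors = [torch.atleast_1d(t) for t in tensors]
    # 1-D inputs concatenate along dim 0, everything else along dim 1,
    # mirroring the AtenCatOp step of DecomposeAtenHstackOp.
    dim = 0 if tensors[0].dim() == 1 else 1
    return torch.cat(tensors, dim=dim)

a = torch.randn(2, 3, 4)
b = torch.randn(2, 5, 4)
# Matches torch.hstack; result shape is (2, 8, 4).
assert torch.equal(hstack_reference([a, b]), torch.hstack([a, b]))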
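The dtype function likewise defers to standard PyTorch type promotion across all inputs, which is what the mixed-dtype `HstackBasic*` e2e tests exercise. A quick sanity check of the behavior the dtype function is expected to reproduce:

import torch

# bool + int32 + int64 promotes to int64; int32 + float64 promotes to float64.
assert torch.hstack([torch.zeros(2, dtype=torch.bool),
                     torch.zeros(2, dtype=torch.int32),
                     torch.zeros(2, dtype=torch.int64)]).dtype == torch.int64
assert torch.hstack([torch.zeros(2, dtype=torch.int32),
                     torch.zeros(2, dtype=torch.float64)]).dtype == torch.float64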