From e81282ae8f23a34b5a7dbef94bea3dfdf945ebcd Mon Sep 17 00:00:00 2001
From: James Newling
Date: Wed, 15 Nov 2023 08:34:38 -0800
Subject: [PATCH] Support for prims collapse op (lowering to linalg) (#2572)

Steps taken:
1) Add generator code to torch_ods_gen.py and run update_torch_ods.sh.
2) Add (custom) shape and type inference generator code to
   abstract_interp_lib_gen.py and run update_abstract_interp_lib.sh.
3) Implement the lowering to tensor.collapse_shape. The lowering requires the
   `start` and `end` values to be compile-time constants; otherwise the
   pattern fails to match.
4) Update xfail_sets.py (append to LTC_XFAIL_SET) after running
   ./tools/e2e_test.sh --filter Collapse --verbose -c XX
   for all supported backends (XX).

Motivation:
- Supporting the collapse operation will be useful for lowering
  pixel_shuffle (see Issue #2559).
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 25 +++++
 .../TorchToLinalg/Uncategorized.cpp           | 71 +++++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      | 78 +++++++++++++++
 projects/pt1/e2e_testing/xfail_sets.py        |  5 +
 .../build_tools/abstract_interp_lib_gen.py    | 43 ++++++++
 .../jit_ir/build_tools/torch_ods_gen.py       |  1 +
 .../test_suite/elementwise.py                 |  1 +
 .../test_suite/reshape_like.py                | 99 +++++++++++++++++++
 8 files changed, 323 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 37ffc8893..8976fe4c2 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -14185,6 +14185,31 @@ def Torch_PrimsSqrtOp : Torch_Op<"prims.sqrt", [
   }];
 }
 
+def Torch_PrimsCollapseOp : Torch_Op<"prims.collapse", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `prims::collapse : (Tensor, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$a,
+    Torch_IntType:$start,
+    Torch_IntType:$end
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult PrimsCollapseOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void PrimsCollapseOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_PrimsSqueezeOp : Torch_Op<"prims.squeeze", [
     AllowsTypeRefinement,
     ReadOnly
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 7f6f55e87..ee968daff 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -25,6 +25,7 @@
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/ADT/APSInt.h"
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -1298,6 +1299,7 @@ public:
 //   nll_loss_forward[i] = -(input[i][indi]);
 // TODO: `weight`operand is still to be taken care of.
 namespace {
+
 class ConvertAtenNllLossForwardOp
     : public OpConversionPattern<AtenNllLossForwardOp> {
 public:
@@ -1757,6 +1759,71 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertPrimsCollapseOp : public OpConversionPattern<PrimsCollapseOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(PrimsCollapseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    auto aRankedTensorType = adaptor.getA().getType().cast<RankedTensorType>();
+    const TypeConverter *typeConverter = getTypeConverter();
+
+    auto resultRankedTensorType =
+        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+
+    // Collapse range must be statically known.
+    int64_t startInt;
+    if (!matchPattern(op.getStart(), m_TorchConstantInt(&startInt)))
+      return failure();
+
+    int64_t endInt;
+    if (!matchPattern(op.getEnd(), m_TorchConstantInt(&endInt)))
+      return failure();
+
+    // Upstream MLIR is overly strict -- it fails verification if the
+    // collapse_shape is the identity op (i.e. when no dimensions are
+    // collapsed). We manually fold this case here.
+    if (startInt == endInt) {
+      rewriter.replaceOp(op, adaptor.getA());
+      return success();
+    }
+
+    SmallVector<ReassociationIndices> associations;
+    associations.reserve(resultRankedTensorType.getRank());
+
+    // An example is where the input shape is [3,4,5,6] and
+    // start = 1 and end = 2. The collapsed shape is then [3,4*5,6],
+    // with reassociation indices of [0], [1,2], and [3].
+
+    // Append the singleton dimensions before the collapsed dimensions.
+    for (unsigned i = 0; i < startInt; ++i) {
+      associations.push_back(ReassociationIndices{i});
+    }
+
+    // Append the collapsed dimensions.
+    ReassociationIndices collapseDims(endInt + 1 - startInt);
+    std::iota(collapseDims.begin(), collapseDims.end(), startInt);
+    associations.push_back(collapseDims);
+
+    // Append the singleton dimensions after the collapsed dimensions.
+ for (int i = endInt + 1; i < aRankedTensorType.getRank(); ++i) { + associations.push_back(ReassociationIndices{i}); + } + + + rewriter.replaceOpWithNewOp( + op, resultRankedTensorType, adaptor.getA(), associations); + + return success(); + } +}; +} // namespace + namespace { class ConvertTensorStaticInfoCastOp : public OpConversionPattern { @@ -1805,6 +1872,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + + target.addIllegalOp(); + patterns.add(typeConverter, context); + target.addIllegalOp(); patterns.add(typeConverter, context); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index c29c2f7b3..92f8c8006 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6461,6 +6461,80 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.prims.collapse\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %str = torch.constant.str \"AssertionError: start must be less than or equal to end\"\n" +" %str_0 = torch.constant.str \"AssertionError: end out of bounds\"\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: start out of bounds\"\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.le.int %arg1, %0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %3 = torch.aten.le.int %arg2, %2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.ge.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.ge.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.le.int %arg1, %arg2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %arg1, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %15 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten.append.t %7, %15 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %8 = 
torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.aten.__range_length %arg1, %8, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %10 = torch.prim.Loop %9, %true, init(%int1) {\n" +" ^bb0(%arg3: !torch.int, %arg4: !torch.int):\n" +" %15 = torch.aten.__derive_index %arg3, %arg1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %16 = torch.aten.__getitem__.t %arg0, %15 : !torch.list, !torch.int -> !torch.int\n" +" %17 = torch.aten.mul.int %arg4, %16 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop.condition %true, iter(%17 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" %11 = torch.aten.append.t %7, %10 : !torch.list, !torch.int -> !torch.list\n" +" %12 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %13 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %14 = torch.aten.__range_length %12, %13, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %14, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %15 = torch.aten.__derive_index %arg3, %12, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %16 = torch.aten.__getitem__.t %arg0, %15 : !torch.list, !torch.int -> !torch.int\n" +" %17 = torch.aten.append.t %7, %16 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %7 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.to.dtype\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -11295,6 +11369,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.collapse\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" "}\n" ""; // clang-format on diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index aaabd18e2..2d26a3687 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1355,6 +1355,11 @@ LTC_CRASHING_SET = { } LTC_XFAIL_SET = { + "CollapseAllDimensionsModule_basic", + "CollapseRank1DynamicModule_basic", + "CollapseStaticModule_basic", + "CollapsePartialDynamicModule_basic", + "CollapseFullDynamicModule_basic", "PixelShuffleModuleStaticRank3Int64_basic", "PixelShuffleModuleStaticRank4Float32_basic", "_Convolution2DAllFalseModule_basic", diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index aa56b4b57..0a2c8ace6 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -177,6 +177,8 @@ def aten〇glu〡shape(self: List[int], dim: int = -1) -> List[int]: assert self[dim] % 2 == 0, "glu's dim size must be multiply of 2" return self[:dim] + [self[dim] // 2] + self[dim+1:] + + def 
aten〇_softmax〡shape(self: List[int], dim: int, half_to_float: bool) -> List[int]: return upstream_shape_functions.unary(self) @@ -204,6 +206,40 @@ def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1 def prims〇convert_element_type〡shape(a: List[int], dtype: int) -> List[int]: return upstream_shape_functions.unary(a) +def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]: + # Obtained through trial and error on a few examples in PyTorch: + assert start <= len(a), "start out of bounds" + assert end <= len(a), "end out of bounds" + assert start >= 0, "start out of bounds" + assert end >= 0, "end out of bounds" + assert start <= end, "start must be less than or equal to end" + + # Example: + # + # torch._prims.collapse(torch.empty(2,3,4), 1,2).shape + # is + # torch.Size([2, 12]) + + collapsed: List[int] = [] + for i in range(start): + collapsed.append(a[i]) + + # For the example, here collapsed is [2] + combined = 1 + for i in range(start, end + 1): + combined *= a[i] + + collapsed.append(combined) + + # For the example, here collapsed is [2, 12] + + for i in range(end + 1, len(a)): + collapsed.append(a[i]) + + # For the example, here collapsed is [2, 12] + + return collapsed + def aten〇to〇dtype〡shape(self: List[int], dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]: return upstream_shape_functions.unary(self) @@ -905,6 +941,7 @@ def aten〇squeeze〇dim〡shape(self: List[int], dim: int) -> List[int]: def prims〇squeeze〡shape(a: List[int], dimensions: List[int]) -> List[int]: return upstream_shape_functions.squeeze_dims(a, dimensions) + def prims〇view_of〡shape(a: List[int]) -> List[int]: return a @@ -3693,6 +3730,12 @@ def prims〇squeeze〡dtype(a_rank_dtype: Tuple[int, int], dimensions: List[int] return a_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, start=0, end = 0)) +def prims〇collapse〡dtype(a_rank_dtype: Tuple[int, int], start: int, end: int) -> int: + a_rank, a_dtype = a_rank_dtype + return a_dtype + + # ============================================================================== # Main diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 4648506ac..4c114eda8 100644 --- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -817,6 +817,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("prims::convert_element_type : (Tensor, int) -> (Tensor)") emit("prims::var : (Tensor, int[]?, float, int?) 
-> (Tensor)") emit("prims::sqrt : (Tensor) -> (Tensor)") + emit("prims::collapse : (Tensor, int, int) -> (Tensor)") emit("prims::squeeze : (Tensor, int[]) -> (Tensor)") emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 00b07cbb6..74c3050f6 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -341,6 +341,7 @@ def ElementwiseUnsqueezeBroadcastModule_basic(module, tu: TestUtils): module.forward(tu.rand(4), tu.rand()) + # ============================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index e2af562a5..a0ee6221b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -122,6 +122,105 @@ class ViewDynamicExpandModule(torch.nn.Module): def ViewDynamicExpandModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 30, 384)) +# ============================================================================== +# +class CollapseAllDimensionsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2,2,2,2], torch.float32, True)]) + + def forward(self, a): + return torch.ops.prims.collapse(a, 0, 3) + + +@register_test_case( + module_factory=lambda: CollapseAllDimensionsModule()) +def CollapseAllDimensionsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2,2,2,2)) + +# ============================================================================== +# +class CollapseRank1DynamicModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True)]) + + def forward(self, a): + return torch.ops.prims.collapse(a, 0, 0) + +@register_test_case( + module_factory=lambda: CollapseRank1DynamicModule()) +def CollapseRank1DynamicModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5)) + +# ============================================================================== +# +class CollapseStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2,3,4], torch.float32, True)]) + + def forward(self, a): + return torch.ops.prims.collapse(a, 1, 2) + + +@register_test_case( + module_factory=lambda: CollapseStaticModule()) +def CollapseStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2,3,4)) + +# ============================================================================== +# +class CollapsePartialDynamicModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1,-1,4,5], torch.float32, True)]) + + def forward(self, a): + return torch.ops.prims.collapse(a, 1, 2) + + +@register_test_case( + module_factory=lambda: CollapsePartialDynamicModule()) +def CollapsePartialDynamicModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2,3,4,5)) + +class CollapseFullDynamicModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1,-1,-1], torch.float32, True)]) + + def forward(self, a): + return torch.ops.prims.collapse(a, 0,1) + + +@register_test_case( + module_factory=lambda: 
CollapseFullDynamicModule()) +def CollapseFullDynamicModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2,3,5)) + + + # ============================================================================== class ViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module):
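# ==============================================================================
# Standalone reference sketch (not part of the patch above): a plain-Python
# model of the prims.collapse shape rule from prims〇collapse〡shape and of the
# reassociation grouping that ConvertPrimsCollapseOp builds for
# tensor.collapse_shape. The helper names `collapse_shape` and `reassociation`
# are illustrative only and do not exist in torch-mlir.

from typing import List

def collapse_shape(a: List[int], start: int, end: int) -> List[int]:
    # Bounds checks, slightly tighter than the shape function's asserts.
    assert 0 <= start <= end < len(a), "start/end out of bounds"
    combined = 1
    for d in a[start:end + 1]:
        combined *= d
    # Dims before `start`, one combined dim, dims after `end`.
    return a[:start] + [combined] + a[end + 1:]

def reassociation(rank: int, start: int, end: int) -> List[List[int]]:
    # Groups of input dims feeding each output dim, mirroring the
    # ReassociationIndices construction in the lowering.
    groups: List[List[int]] = [[i] for i in range(start)]
    groups.append(list(range(start, end + 1)))
    groups.extend([i] for i in range(end + 1, rank))
    return groups

# Worked examples taken from the comments in the patch:
assert collapse_shape([2, 3, 4], 1, 2) == [2, 12]
assert collapse_shape([3, 4, 5, 6], 1, 2) == [3, 20, 6]
assert reassociation(4, 1, 2) == [[0], [1, 2], [3]]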
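# A quick eager-mode check of the shapes the new Collapse e2e tests expect;
# assumes a PyTorch build that exposes prims.collapse (the shape-function
# comment in abstract_interp_lib_gen.py exercises it as torch._prims.collapse).
import torch

assert torch.ops.prims.collapse(torch.rand(2, 2, 2, 2), 0, 3).shape == torch.Size([16])
assert torch.ops.prims.collapse(torch.rand(2, 3, 4), 1, 2).shape == torch.Size([2, 12])
assert torch.ops.prims.collapse(torch.rand(5), 0, 0).shape == torch.Size([5])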