mirror of https://github.com/llvm/torch-mlir
Support for prims collapse op (lowering to linalg) (#2572)
Steps taken:
1) Add generator code to torch_ods_gen.py and run update_torch_ods.sh.
2) Add (custom) shape and dtype inference generator code to abstract_interp_lib_gen.py and run update_abstract_interp_lib.sh.
3) Implement the lowering to tensor.collapse_shape (tensor::CollapseShapeOp). The `start` and `end` values must be constant, otherwise the lowering fails.
4) Update xfail_sets.py (append to LTC_XFAIL_SET) after running /tools/e2e_test.sh --filter Collapse --verbose -c XX for all supported backends (XX).

Motivation:
- Supporting the collapse operation will be useful for lowering pixel_shuffle (see Issue #2559).
pull/2563/head
parent 6be9789f9f
commit e81282ae8f
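For reference, a minimal sketch of what the op computes in eager PyTorch (illustrative only; it assumes a PyTorch build that exposes torch.ops.prims.collapse, which the e2e tests below also rely on):

import torch

# Collapse dimensions 1..2 (inclusive) of a (2, 3, 4) tensor into one dimension.
a = torch.empty(2, 3, 4)
b = torch.ops.prims.collapse(a, 1, 2)
print(b.shape)  # torch.Size([2, 12])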
@@ -14185,6 +14185,31 @@ def Torch_PrimsSqrtOp : Torch_Op<"prims.sqrt", [
  }];
}

def Torch_PrimsCollapseOp : Torch_Op<"prims.collapse", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `prims::collapse : (Tensor, int, int) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$a,
    Torch_IntType:$start,
    Torch_IntType:$end
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult PrimsCollapseOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void PrimsCollapseOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}

def Torch_PrimsSqueezeOp : Torch_Op<"prims.squeeze", [
    AllowsTypeRefinement,
    ReadOnly
@@ -25,6 +25,7 @@
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/APSInt.h"
#include <numeric>

using namespace mlir;
using namespace mlir::torch;
@@ -1298,6 +1299,7 @@ public:
// nll_loss_forward[i] = -(input[i][indi]);
// TODO: the `weight` operand is still to be taken care of.
namespace {

class ConvertAtenNllLossForwardOp
    : public OpConversionPattern<AtenNllLossForwardOp> {
public:
@@ -1757,6 +1759,71 @@ public:
};
} // namespace

namespace {
class ConvertPrimsCollapseOp : public OpConversionPattern<PrimsCollapseOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(PrimsCollapseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    auto aRankedTensorType = adaptor.getA().getType().cast<RankedTensorType>();
    const TypeConverter *typeConverter = getTypeConverter();

    auto resultRankedTensorType =
        typeConverter->convertType(op.getType()).cast<RankedTensorType>();

    // The collapse range must be statically known.
    int64_t startInt;
    if (!matchPattern(op.getStart(), m_TorchConstantInt(&startInt)))
      return failure();

    int64_t endInt;
    if (!matchPattern(op.getEnd(), m_TorchConstantInt(&endInt)))
      return failure();

    // Upstream MLIR is overly strict -- it fails verification if the
    // collapse_shape is the identity op (i.e. when no dimensions are
    // collapsed). We manually fold this case here.
    if (startInt == endInt) {
      rewriter.replaceOp(op, adaptor.getA());
      return success();
    }

    SmallVector<ReassociationIndices> associations;
    associations.reserve(resultRankedTensorType.getRank());

    // An example: with an input shape of [3,4,5,6], start = 1, and end = 2,
    // the collapsed shape is [3,4*5,6], with reassociation indices of
    // [0], [1,2], and [3].

    // Append the dimensions before the collapsed range, each as its own
    // reassociation group.
    for (unsigned i = 0; i < startInt; ++i) {
      associations.push_back(ReassociationIndices{i});
    }

    // Append the collapsed dimensions as a single reassociation group.
    ReassociationIndices collapseDims(endInt + 1 - startInt);
    std::iota(collapseDims.begin(), collapseDims.end(), startInt);
    associations.push_back(collapseDims);

    // Append the dimensions after the collapsed range, each as its own
    // reassociation group.
    for (int i = endInt + 1; i < aRankedTensorType.getRank(); ++i) {
      associations.push_back(ReassociationIndices{i});
    }

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        op, resultRankedTensorType, adaptor.getA(), associations);

    return success();
  }
};
} // namespace

namespace {
class ConvertTensorStaticInfoCastOp
    : public OpConversionPattern<TensorStaticInfoCastOp> {
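To make the reassociation bookkeeping above concrete, here is a small standalone Python sketch (not part of the patch; the helper name collapse_reassociation is made up) that computes the same grouping from the rank and the constant start/end values:

def collapse_reassociation(rank, start, end):
    # Dims before `start` and after `end` each form their own group;
    # dims start..end (inclusive) form a single group, mirroring
    # ConvertPrimsCollapseOp above.
    groups = [[i] for i in range(start)]
    groups.append(list(range(start, end + 1)))
    groups += [[i] for i in range(end + 1, rank)]
    return groups

# For an input shape of [3, 4, 5, 6] with start=1, end=2:
# collapse_reassociation(4, 1, 2) == [[0], [1, 2], [3]]
# When start == end every group is a singleton (the identity grouping),
# which is why the lowering folds that case instead of emitting
# tensor.collapse_shape.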
@@ -1805,6 +1872,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
  patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
  target.addIllegalOp<AtenBatchNormOp>();
  patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);

  target.addIllegalOp<PrimsCollapseOp>();
  patterns.add<ConvertPrimsCollapseOp>(typeConverter, context);

  target.addIllegalOp<AtenNllLossBackwardOp>();
  patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
  patterns.add<ConvertTensorStaticInfoCastOp>(typeConverter, context);
@@ -6461,6 +6461,80 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.prims.collapse\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %str = torch.constant.str \"AssertionError: start must be less than or equal to end\"\n"
" %str_0 = torch.constant.str \"AssertionError: end out of bounds\"\n"
" %none = torch.constant.none\n"
" %str_1 = torch.constant.str \"AssertionError: start out of bounds\"\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.le.int %arg1, %0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %3 = torch.aten.le.int %arg2, %2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.ge.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.aten.ge.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %6 = torch.aten.le.int %arg1, %arg2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %7 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" torch.prim.Loop %arg1, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %15 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %16 = torch.aten.append.t %7, %15 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %8 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %9 = torch.aten.__range_length %arg1, %8, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %10 = torch.prim.Loop %9, %true, init(%int1) {\n"
" ^bb0(%arg3: !torch.int, %arg4: !torch.int):\n"
" %15 = torch.aten.__derive_index %arg3, %arg1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %16 = torch.aten.__getitem__.t %arg0, %15 : !torch.list<int>, !torch.int -> !torch.int\n"
" %17 = torch.aten.mul.int %arg4, %16 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop.condition %true, iter(%17 : !torch.int)\n"
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
" %11 = torch.aten.append.t %7, %10 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" %12 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %13 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %14 = torch.aten.__range_length %12, %13, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop %14, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %15 = torch.aten.__derive_index %arg3, %12, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %16 = torch.aten.__getitem__.t %arg0, %15 : !torch.list<int>, !torch.int -> !torch.int\n"
" %17 = torch.aten.append.t %7, %16 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" return %7 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.to.dtype\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -11295,6 +11369,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.prims.collapse\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
"}\n"
"";
// clang-format on
@@ -1355,6 +1355,11 @@ LTC_CRASHING_SET = {
}

LTC_XFAIL_SET = {
    "CollapseAllDimensionsModule_basic",
    "CollapseRank1DynamicModule_basic",
    "CollapseStaticModule_basic",
    "CollapsePartialDynamicModule_basic",
    "CollapseFullDynamicModule_basic",
    "PixelShuffleModuleStaticRank3Int64_basic",
    "PixelShuffleModuleStaticRank4Float32_basic",
    "_Convolution2DAllFalseModule_basic",
@@ -177,6 +177,8 @@ def aten〇glu〡shape(self: List[int], dim: int = -1) -> List[int]:
    assert self[dim] % 2 == 0, "glu's dim size must be multiply of 2"
    return self[:dim] + [self[dim] // 2] + self[dim+1:]


def aten〇_softmax〡shape(self: List[int], dim: int, half_to_float: bool) -> List[int]:
    return upstream_shape_functions.unary(self)
@@ -204,6 +206,40 @@ def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1
def prims〇convert_element_type〡shape(a: List[int], dtype: int) -> List[int]:
    return upstream_shape_functions.unary(a)

def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]:
    # Obtained through trial and error on a few examples in PyTorch:
    assert start <= len(a), "start out of bounds"
    assert end <= len(a), "end out of bounds"
    assert start >= 0, "start out of bounds"
    assert end >= 0, "end out of bounds"
    assert start <= end, "start must be less than or equal to end"

    # Example:
    #
    # torch._prims.collapse(torch.empty(2,3,4), 1, 2).shape
    # is
    # torch.Size([2, 12])

    collapsed: List[int] = []
    for i in range(start):
        collapsed.append(a[i])

    # For the example, collapsed is now [2].
    combined = 1
    for i in range(start, end + 1):
        combined *= a[i]

    collapsed.append(combined)

    # For the example, collapsed is now [2, 12].

    for i in range(end + 1, len(a)):
        collapsed.append(a[i])

    # For the example, collapsed is still [2, 12] (no dimensions after end).

    return collapsed

def aten〇to〇dtype〡shape(self: List[int], dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]:
    return upstream_shape_functions.unary(self)
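As a quick sanity check of the shape logic above, a standalone sketch (not the generator module itself; the helper name collapse_shape is made up) reproduces the same arithmetic:

from typing import List

def collapse_shape(a: List[int], start: int, end: int) -> List[int]:
    # Same computation as prims〇collapse〡shape: keep dims before `start`,
    # multiply dims start..end into one, keep dims after `end`.
    combined = 1
    for d in a[start:end + 1]:
        combined *= d
    return a[:start] + [combined] + a[end + 1:]

assert collapse_shape([2, 3, 4], 1, 2) == [2, 12]       # the PyTorch example above
assert collapse_shape([3, 4, 5, 6], 1, 2) == [3, 20, 6]
assert collapse_shape([5], 0, 0) == [5]                  # identity collapse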
@@ -905,6 +941,7 @@ def aten〇squeeze〇dim〡shape(self: List[int], dim: int) -> List[int]:
def prims〇squeeze〡shape(a: List[int], dimensions: List[int]) -> List[int]:
    return upstream_shape_functions.squeeze_dims(a, dimensions)


def prims〇view_of〡shape(a: List[int]) -> List[int]:
    return a
@@ -3693,6 +3730,12 @@ def prims〇squeeze〡dtype(a_rank_dtype: Tuple[int, int], dimensions: List[int]
    return a_dtype


@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, start=0, end=0))
def prims〇collapse〡dtype(a_rank_dtype: Tuple[int, int], start: int, end: int) -> int:
    a_rank, a_dtype = a_rank_dtype
    return a_dtype



# ==============================================================================
# Main
@@ -817,6 +817,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("prims::convert_element_type : (Tensor, int) -> (Tensor)")
    emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)")
    emit("prims::sqrt : (Tensor) -> (Tensor)")
    emit("prims::collapse : (Tensor, int, int) -> (Tensor)")
    emit("prims::squeeze : (Tensor, int[]) -> (Tensor)")
    emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True)
@@ -341,6 +341,7 @@ def ElementwiseUnsqueezeBroadcastModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4), tu.rand())


# ==============================================================================
@@ -122,6 +122,105 @@ class ViewDynamicExpandModule(torch.nn.Module):
def ViewDynamicExpandModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 4, 30, 384))

# ==============================================================================
#
class CollapseAllDimensionsModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 2, 2, 2], torch.float32, True)])
    def forward(self, a):
        return torch.ops.prims.collapse(a, 0, 3)


@register_test_case(
    module_factory=lambda: CollapseAllDimensionsModule())
def CollapseAllDimensionsModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 2, 2, 2))

# ==============================================================================
#
class CollapseRank1DynamicModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True)])
    def forward(self, a):
        return torch.ops.prims.collapse(a, 0, 0)

@register_test_case(
    module_factory=lambda: CollapseRank1DynamicModule())
def CollapseRank1DynamicModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5))

# ==============================================================================
#
class CollapseStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 3, 4], torch.float32, True)])
    def forward(self, a):
        return torch.ops.prims.collapse(a, 1, 2)


@register_test_case(
    module_factory=lambda: CollapseStaticModule())
def CollapseStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))

# ==============================================================================
#
class CollapsePartialDynamicModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, 4, 5], torch.float32, True)])
    def forward(self, a):
        return torch.ops.prims.collapse(a, 1, 2)


@register_test_case(
    module_factory=lambda: CollapsePartialDynamicModule())
def CollapsePartialDynamicModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4, 5))

class CollapseFullDynamicModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True)])
    def forward(self, a):
        return torch.ops.prims.collapse(a, 0, 1)


@register_test_case(
    module_factory=lambda: CollapseFullDynamicModule())
def CollapseFullDynamicModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 5))


# ==============================================================================

class ViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module):