Implement lowering of torch.aten.atleast_1d (#3498)

This operator is necessary in order to implement torch.aten.vstack.
Which will be added in a future PR.
pull/3550/head
pkapris-syrmia 2024-07-17 14:50:30 +02:00 committed by GitHub
parent 574143448b
commit b59efc75f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 141 additions and 0 deletions

View File

@ -10148,6 +10148,29 @@ def Torch_AtenOneHotOp : Torch_Op<"aten.one_hot", [
}];
}
def Torch_AtenAtleast1dOp : Torch_Op<"aten.atleast_1d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::atleast_1d : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAtleast1dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenAtleast1dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenEinsumOp : Torch_Op<"aten.einsum", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -10334,6 +10334,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.atleast_1d\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
" %3 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %3 : !torch.list<int>\n"
" } else {\n"
" torch.prim.If.yield %arg0 : !torch.list<int>\n"
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.stack\"(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.stack(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -14517,6 +14530,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.atleast_1d\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.optional<list<int>>) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"

View File

@ -1476,6 +1476,44 @@ public:
};
} // namespace
namespace {
// Decompose aten.atleast_1d into: aten.reshape. See
// https://github.com/pytorch/pytorch/blob/9a8ab778d34bd24c5caceb340837483decc4c311/torch/_refs/__init__.py#L2591
// def atleast_1d(
// arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args:
// TensorLikeType
// ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
// """Refrence implementation of :func:`torch.atleast_1d`."""
// if not args and isinstance(arg, collections.abc.Sequence):
// args_ = arg
// else:
// assert not isinstance(arg, collections.abc.Sequence)
// args_ = (arg,) + args
// res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_)
// return res if len(res) > 1 else res[0]
class DecomposeAtenAtleast1dOp : public OpRewritePattern<AtenAtleast1dOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAtleast1dOp op,
PatternRewriter &rewriter) const override {
Value input = op.getSelf();
Location loc = op.getLoc();
Type opType = op.getType();
auto inpType = cast<BaseTensorType>(input.getType());
SmallVector<int64_t> inputShape(inpType.getSizes());
if (inputShape.empty()) {
Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
rewriter.replaceOpWithNewOp<AtenUnsqueezeOp>(op, opType, input, zero);
return success();
}
rewriter.replaceOp(op, input);
return success();
}
};
} // namespace
namespace {
// Decompose AtenEinsumOp to AtenMatmulOp, and supports possible reduce
// operation and permute operation. Currently, this pass doesn't support
@ -8863,6 +8901,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);

View File

@ -394,6 +394,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenReshapeOp>();
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
target.addIllegalOp<AtenTanhBackwardOp>();
target.addIllegalOp<AtenAtleast1dOp>();
target.addIllegalOp<AtenEinsumOp>();
target.addIllegalOp<AtenTraceOp>();
target.addIllegalOp<AtenAddmmOp>();

View File

@ -840,6 +840,8 @@ STABLEHLO_PASS_SET = {
"TensorSplitSections_ListUnpackModule_basic",
"EmptyModule_uint8",
"TypeConversionUint8ToF32Module_basic",
"Atleast1dModule0dInput_basic",
"Atleast1dModule1dInput_basic",
"AtenLinear1D_basic",
"AtenLinear2D_basic",
"AtenLinear3DBias_basic",
@ -1507,6 +1509,8 @@ TOSA_PASS_SET = {
"AvgPool2dCountIncludePadFalseStaticModule_basic",
"TensorSplitSections_GetItemModule_basic",
"TensorSplitSections_ListUnpackModule_basic",
"Atleast1dModule0dInput_basic",
"Atleast1dModule1dInput_basic",
"AtenLinear2D_basic",
"AtenLinear3DBias_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
@ -1985,6 +1989,8 @@ MAKE_FX_TOSA_PASS_SET = (
"AtenLinear1D_basic",
"AtenLinearMatVec_basic",
"AtenLinearVecMatBias_basic",
"Atleast1dModule0dInput_basic",
"Atleast1dModule1dInput_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",

View File

@ -2057,6 +2057,12 @@ def atenindexTensor_hacked_twin〡shape(self: List[int], indices: List[Lis
def atencat〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]:
return upstream_shape_functions.cat(tensors, dim)
def atenatleast_1d〡shape(self: List[int]) -> List[int]:
if len(self) == 0:
return [1]
else:
return self
def atenstack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]:
return upstream_shape_functions.stack(tensors, dim)
@ -5095,6 +5101,11 @@ def atencat〡dtype(tensors_rank_dtype: List[Tuple[int, int]], dim: int = 0)
dtypes.append(tensor_dtype)
return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atenatleast_1d〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(
[Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32),
TensorOfShape(1, dtype=torch.int32)]),])

View File

@ -783,6 +783,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)")
emit("aten::one_hot : (Tensor, int) -> (Tensor)")
emit("aten::atleast_1d : (Tensor) -> (Tensor)")
emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)")
emit("aten::trace : (Tensor) -> (Tensor)")
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")

View File

@ -1461,3 +1461,46 @@ def InterpolateDynamicModule_scales_recompute_bilinear(module, tu: TestUtils):
input = torch.arange(20).to(dtype=torch.float32)
input = input.reshape((1, 1, 4, 5))
module.forward(input)
# ==============================================================================
class Atleast1dModule0dInput(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.atleast_1d(x)
@register_test_case(module_factory=lambda: Atleast1dModule0dInput())
def Atleast1dModule0dInput_basic(module, tu: TestUtils):
module.forward(tu.rand())
class Atleast1dModule1dInput(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([4], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.atleast_1d(x)
@register_test_case(module_factory=lambda: Atleast1dModule1dInput())
def Atleast1dModule1dInput_basic(module, tu: TestUtils):
module.forward(tu.rand(4))