Implement lowering of torch.aten.all.dim (#2873)

Lowering of torch.aten.all.dim to linalg.

Per PyTorch documentation:

> This function matches the behaviour of NumPy in returning output of
> dtype bool for all supported dtypes except uint8. For uint8 the dtype of
> output is uint8 itself.

Since there is currently no support for unsigned 8-bit integers (ui8) in torch-mlir
(https://github.com/llvm/torch-mlir/pull/1384#issuecomment-1260011334),
the implementation returns failure for that case.
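
For reference, this dtype behaviour can be checked directly in eager PyTorch (a quick illustration, assuming a stock PyTorch install; not part of the change itself):

import torch

x = torch.tensor([[1, 0, 3]], dtype=torch.int32)
print(torch.all(x, dim=1).dtype)   # torch.bool

u = torch.tensor([[1, 0, 3]], dtype=torch.uint8)
print(torch.all(u, dim=1).dtype)   # torch.uint8 (the NumPy-compatibility exception)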
mmakevic 2024-02-07 21:34:52 +01:00 committed by GitHub
parent fc04bc7ee9
commit 32dbf99ce2
4 changed files with 115 additions and 0 deletions


@@ -277,6 +277,10 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
  if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op))
    return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));

  if (isa<AtenAllDimOp>(op)) {
    return b.create<arith::ConstantOp>(loc, b.getBoolAttr(true));
  }

  op->emitError("unimplemented lowering in createInitElementForReduceOp");
  return nullptr;
}
@@ -357,6 +361,11 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
    auto ord = b.create<arith::ConstantOp>(loc, twoAttr);
    auto pow = b.create<math::PowFOp>(loc, abs, ord);
    return b.create<arith::AddFOp>(loc, pow, result);
  } else if (isa<AtenAllDimOp>(op)) {
    Value elem = payloadArgs[0];
    Value result = payloadArgs[1];
    Value self = convertScalarToDtype(b, loc, elem, resultElementType);
    return b.create<arith::MulIOp>(loc, self, result);
  }
  op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp");
  return nullptr;
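
Conceptually, this payload together with the `true` init element computes `all` along the reduced dimension as a multiplicative (logical AND) reduction over the elements converted to the result element type. A rough eager-mode sketch of the same semantics (illustrative only; `all_dim_reference` is a hypothetical helper, not anything generated by this lowering):

import torch

def all_dim_reference(x: torch.Tensor, dim: int, keepdim: bool = False) -> torch.Tensor:
    # Map every element to bool (non-zero -> True), then reduce with
    # multiplication; the identity element of the reduction is True (1),
    # matching the init value created above.
    as_int = (x != 0).to(torch.int64)
    return torch.prod(as_int, dim=dim, keepdim=keepdim).to(torch.bool)

x = torch.tensor([[5.0, 1e-6, -5.0], [0.0, 5.0, 0.0]])
assert torch.equal(all_dim_reference(x, dim=1), torch.all(x, dim=1))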
@@ -447,6 +456,9 @@ private:
    if (auto normOp = dyn_cast<AtenFrobeniusNormDimOp>(op))
      return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter);
    if (auto allOp = dyn_cast<AtenAllDimOp>(op))
      return computeReductionOpInfoForDimVariantOp(allOp, operands, rewriter);
    return rewriter.notifyMatchFailure(op, "not a supported reduce op");
  }
@@ -535,6 +547,9 @@ private:
        !elemType.isa<mlir::FloatType>())
      return rewriter.notifyMatchFailure(
          op, "only float types are valid for vector norm ops");
    if (isa<AtenAllDimOp>(op) && elemType.isa<mlir::IntegerType>() &&
        elemType.getIntOrFloatBitWidth() == 8)
      return rewriter.notifyMatchFailure(op, "uint8 is not supported");
    // No checks for all other reduction operations
    return success();
  }
@@ -610,6 +625,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
  target.addIllegalOp<AtenProdDimIntOp>();
  target.addIllegalOp<AtenMaxOp>();
  target.addIllegalOp<AtenMinOp>();
  target.addIllegalOp<AtenAllDimOp>();
  target.addIllegalOp<AtenLinalgVectorNormOp>();
  target.addIllegalOp<AtenFrobeniusNormDimOp>();
  patterns.add<ConvertReductionOp>(typeConverter, context);
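
With the pattern registered and aten.all.dim marked illegal, a module containing the op should now lower through the torch-to-linalg pipeline. A minimal sketch, assuming the pt1 TorchScript-based entry point (torch_mlir.torchscript.compile on builds around this commit; older releases exposed the same API as torch_mlir.compile) and its "linalg-on-tensors" output type:

import torch
from torch_mlir import torchscript

class AllDim(torch.nn.Module):
    def forward(self, x):
        # torch.all with a dim argument maps to aten.all.dim.
        return torch.all(x, dim=1)

# Lower to linalg-on-tensors and print the resulting IR.
module = torchscript.compile(AllDim(), torch.ones(2, 3, dtype=torch.bool),
                             output_type="linalg-on-tensors")
print(module)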


@@ -7006,6 +7006,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
"    return %1 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.all.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
"    %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
"    %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
"    return %1 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.max.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
"    %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
"    %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
@@ -11809,6 +11814,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    }\n"
"    return %2 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.all.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
"    %int11 = torch.constant.int 11\n"
"    %int0 = torch.constant.int 0\n"
"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
"    %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
"    %2 = torch.prim.If %1 -> (!torch.int) {\n"
"      torch.prim.If.yield %0#1 : !torch.int\n"
"    } else {\n"
"      torch.prim.If.yield %int11 : !torch.int\n"
"    }\n"
"    return %2 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.min\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
"    return %0#1 : !torch.int\n"


@@ -543,6 +543,9 @@ def aten〇one_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]:
def aten〇any〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
    return upstream_shape_functions.argmax(self, dim, keepdim)

def aten〇all〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
    return upstream_shape_functions.argmax(self, dim, keepdim)

def aten〇max〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> Tuple[List[int], List[int]]:
    reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
    return reduced_shape, reduced_shape
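
The shape function can reuse the upstream argmax shape helper because any single-dimension reduction produces the same output shape for a given dim/keepdim. A quick eager-PyTorch illustration of that shape rule (not part of the change):

import torch

x = torch.zeros(2, 3, 4)
print(torch.all(x, dim=1).shape)                # torch.Size([2, 4])
print(torch.all(x, dim=1, keepdim=True).shape)  # torch.Size([2, 1, 4])
print(torch.argmax(x, dim=1).shape)             # torch.Size([2, 4]), same shape rule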
@@ -3766,6 +3769,13 @@ def aten〇any〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim
        return self_dtype
    return torch.bool

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
def aten〇all〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> int:
    self_rank, self_dtype = self_rank_dtype
    if self_dtype == torch.uint8:
        return self_dtype
    return torch.bool

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇min〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype

@@ -316,6 +316,78 @@ def ReduceProdDimIntFloatModule_basic(module, tu: TestUtils):
# ==============================================================================
class ReduceAllDimEmpty(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.all(a, dim=0, keepdim=False)

@register_test_case(module_factory=lambda: ReduceAllDimEmpty())
def ReduceAllDimEmpty_basic(module, tu: TestUtils):
    module.forward(torch.tensor([]))

# ==============================================================================

class ReduceAllDimFloat(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.all(a, dim=1, keepdim=True)

@register_test_case(module_factory=lambda: ReduceAllDimFloat())
def ReduceAllDimFloat_basic(module, tu: TestUtils):
    module.forward(torch.tensor([[5.0, 1e-6, -5.0], [0, 5.0, 0]]))

# ==============================================================================

class ReduceAllDimInt(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.all(a, dim=1, keepdim=True)

@register_test_case(module_factory=lambda: ReduceAllDimInt())
def ReduceAllDimInt_basic(module, tu: TestUtils):
    module.forward(torch.tensor([[5, -5, 0], [5, 1e10, 5]]).to(torch.int32))

# ==============================================================================

class ReduceAllDimBool(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.bool, True),
    ])
    def forward(self, a):
        return torch.ops.aten.all(a, dim=1, keepdim=False)

@register_test_case(module_factory=lambda: ReduceAllDimBool())
def ReduceAllDimBool_basic(module, tu: TestUtils):
    module.forward(torch.tensor([[True, False, True], [True, True, True]]))
# ==============================================================================
class ReduceMaxAlongDim(torch.nn.Module):
    def __init__(self):
        super().__init__()
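
For reference, the expected eager-mode results for the new test inputs above, which the e2e harness uses as the golden values to compare the lowered code against (computed with stock PyTorch):

import torch

print(torch.ops.aten.all(torch.tensor([]), dim=0, keepdim=False))
# tensor(True) -- reducing over an empty dimension is vacuously true

print(torch.ops.aten.all(torch.tensor([[5.0, 1e-6, -5.0], [0, 5.0, 0]]), dim=1, keepdim=True))
# tensor([[ True],
#         [False]])

print(torch.ops.aten.all(torch.tensor([[True, False, True], [True, True, True]]), dim=1, keepdim=False))
# tensor([False,  True])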