mirror of https://github.com/llvm/torch-mlir
Implement lowering of torch.aten.all.dim (#2873)
Lowering of torch.aten.all.dim to linalg. Per PyTorch documentation: > This function matches the behaviour of NumPy in returning output of dtype bool for all supported dtypes except uint8. For uint8 the dtype of output is uint8 itself. Since there is no support for ui8 in torch-mlir currently (https://github.com/llvm/torch-mlir/pull/1384#issuecomment-1260011334) implementation returns failure for that case.pull/2887/head
parent
fc04bc7ee9
commit
32dbf99ce2
|
@ -277,6 +277,10 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
|
|||
if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op))
|
||||
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
||||
|
||||
if (isa<AtenAllDimOp>(op)) {
|
||||
return b.create<arith::ConstantOp>(loc, b.getBoolAttr(true));
|
||||
}
|
||||
|
||||
op->emitError("unimplemented lowering in createInitElementForReduceOp");
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -357,6 +361,11 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
|
|||
auto ord = b.create<arith::ConstantOp>(loc, twoAttr);
|
||||
auto pow = b.create<math::PowFOp>(loc, abs, ord);
|
||||
return b.create<arith::AddFOp>(loc, pow, result);
|
||||
} else if (isa<AtenAllDimOp>(op)) {
|
||||
Value elem = payloadArgs[0];
|
||||
Value result = payloadArgs[1];
|
||||
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
|
||||
return b.create<arith::MulIOp>(loc, self, result);
|
||||
}
|
||||
op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp");
|
||||
return nullptr;
|
||||
|
@ -447,6 +456,9 @@ private:
|
|||
if (auto normOp = dyn_cast<AtenFrobeniusNormDimOp>(op))
|
||||
return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter);
|
||||
|
||||
if (auto allOp = dyn_cast<AtenAllDimOp>(op))
|
||||
return computeReductionOpInfoForDimVariantOp(allOp, operands, rewriter);
|
||||
|
||||
return rewriter.notifyMatchFailure(op, "not a supported reduce op");
|
||||
}
|
||||
|
||||
|
@ -535,6 +547,9 @@ private:
|
|||
!elemType.isa<mlir::FloatType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only float types are valid for vector norm ops");
|
||||
if (isa<AtenAllDimOp>(op) && elemType.isa<mlir::IntegerType>() &&
|
||||
elemType.getIntOrFloatBitWidth() == 8)
|
||||
return rewriter.notifyMatchFailure(op, "uint8 is not supported");
|
||||
// No checks for all other reduction operations
|
||||
return success();
|
||||
}
|
||||
|
@ -610,6 +625,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
|
|||
target.addIllegalOp<AtenProdDimIntOp>();
|
||||
target.addIllegalOp<AtenMaxOp>();
|
||||
target.addIllegalOp<AtenMinOp>();
|
||||
target.addIllegalOp<AtenAllDimOp>();
|
||||
target.addIllegalOp<AtenLinalgVectorNormOp>();
|
||||
target.addIllegalOp<AtenFrobeniusNormDimOp>();
|
||||
patterns.add<ConvertReductionOp>(typeConverter, context);
|
||||
|
|
|
@ -7006,6 +7006,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
|
||||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.all.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
|
||||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.max.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||
" %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
|
||||
|
@ -11809,6 +11814,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.all.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %0#1 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.min\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -543,6 +543,9 @@ def aten〇one_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]:
|
|||
def aten〇any〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.argmax(self, dim, keepdim)
|
||||
|
||||
def aten〇all〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.argmax(self, dim, keepdim)
|
||||
|
||||
def aten〇max〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> Tuple[List[int], List[int]]:
|
||||
reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
|
||||
return reduced_shape, reduced_shape
|
||||
|
@ -3766,6 +3769,13 @@ def aten〇any〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim
|
|||
return self_dtype
|
||||
return torch.bool
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
|
||||
def aten〇all〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
if self_dtype == torch.uint8:
|
||||
return self_dtype
|
||||
return torch.bool
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇min〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -316,6 +316,78 @@ def ReduceProdDimIntFloatModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceAllDimEmpty(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.all(a, dim=0, keepdim=False)
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceAllDimEmpty())
|
||||
def ReduceAllDimEmpty_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([]))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceAllDimFloat(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1,-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.all(a, dim=1, keepdim=True)
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceAllDimFloat())
|
||||
def ReduceAllDimFloat_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([[5.0,1e-6,-5.0],[0,5.0,0]]))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceAllDimInt(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1,-1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.all(a, dim=1, keepdim=True)
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceAllDimInt())
|
||||
def ReduceAllDimInt_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([[5,-5,0],[5,1e10,5]]).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceAllDimBool(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1,-1], torch.bool, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.all(a, dim=1, keepdim=False)
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceAllDimBool())
|
||||
def ReduceAllDimBool_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([[True, False, True], [True, True, True]]))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceMaxAlongDim(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue