Add aten.min.dim to linalg lowering (#2600)

pull/2612/head
Frederik Harwath 2023-12-05 16:16:35 +01:00 committed by GitHub
parent d0b49a912e
commit 6248216dca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 223 additions and 60 deletions

View File

@ -30,70 +30,80 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
namespace { namespace {
// Aten maxdim lowering represents the MaxDim op as an linalg.indexed_generic // Aten max.dim (min.dim) lowering represents the MaxDimOp (MinDimOp) as an
// op, producing two output buffers. // linalg.indexed_generic op, producing two output buffers.
// //
// The first output buffer contains the maximum value found. It is initialized // The first output buffer contains the maximum (minium) value found. It is
// to the minimum representable value of the input element type. // initialized to the minimum (maximum) representable value of the input
// element type.
// //
// The second output buffer contains the index of the found maximum value. It is // The second output buffer contains the index of the found maximum (minimum)
// initialized to 0 and is resulting integer type. // value. It is initialized to 0 and is resulting integer type.
// //
// The indexed_generic op updates both the maximum value and index if the // The indexed_generic op updates both the maximum (minimum) value and index
// current value exceeds the running max. // if the current value exceeds the running max (min).
class ConvertAtenMaxDimOp : public OpConversionPattern<AtenMaxDimOp> { template <typename OpTy>
class ConvertAtenMinMaxDimOp : public OpConversionPattern<OpTy> {
public: public:
using OpConversionPattern<AtenMaxDimOp>::OpConversionPattern; using OpConversionPattern<OpTy>::OpConversionPattern;
using OpConversionPattern<OpTy>::getTypeConverter;
using OpAdaptor = typename OpTy::Adaptor;
LogicalResult LogicalResult
matchAndRewrite(AtenMaxDimOp maxDimOp, OpAdaptor adaptor, matchAndRewrite(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
static_assert(std::is_same<OpTy, AtenMaxDimOp>() ||
std::is_same<OpTy, AtenMinDimOp>());
constexpr bool isMax = std::is_same<OpTy, AtenMaxDimOp>();
const llvm::StringRef opName = op->getName().getStringRef();
Location loc = maxDimOp.getLoc(); Location loc = op.getLoc();
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
RankedTensorType valResultType = RankedTensorType valResultType =
getTypeConverter() getTypeConverter()
->convertType(maxDimOp.getResult(0).getType()) ->convertType(op.getResult(0).getType())
.cast<RankedTensorType>(); .template cast<RankedTensorType>();
RankedTensorType idxResultType = RankedTensorType idxResultType =
getTypeConverter() this->getTypeConverter()
->convertType(maxDimOp.getResult(1).getType()) ->convertType(op.getResult(1).getType())
.cast<RankedTensorType>(); .template cast<RankedTensorType>();
RankedTensorType inputType = input.getType().cast<RankedTensorType>(); RankedTensorType inputType =
input.getType().template cast<RankedTensorType>();
Type idxElementType = idxResultType.getElementType(); Type idxElementType = idxResultType.getElementType();
if (!idxElementType.isa<IntegerType>()) if (!idxElementType.isa<IntegerType>())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
maxDimOp, op, opName + " to linalg.* requires integer-like result type");
"aten.max_dim to linalg.* requires integer-like result type");
bool keepDim = false; bool keepDim = false;
if (!matchPattern(maxDimOp.getKeepdim(), m_TorchConstantBool(&keepDim))) if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim requires boolean value for keepdim"); op, opName + " requires boolean value for keepdim");
int64_t dim; int64_t dim;
if (!matchPattern(maxDimOp.getDim(), m_TorchConstantInt(&dim))) if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim to linalg.* requires int value for Dim"); op, opName + " to linalg.* requires int value for Dim");
dim = toPositiveDim(dim, inputType.getRank()); dim = toPositiveDim(dim, inputType.getRank());
if (!isValidDim(dim, inputType.getRank())) if (!isValidDim(dim, inputType.getRank()))
return rewriter.notifyMatchFailure(maxDimOp, "dim is not a valid dim"); return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
Type inElementType = inputType.getElementType(); Type inElementType = inputType.getElementType();
if (!inElementType.isa<mlir::FloatType>()) { if (!inElementType.isa<mlir::FloatType>()) {
if (inElementType.isa<mlir::IntegerType>()) { if (inElementType.isa<mlir::IntegerType>()) {
auto integerTy = maxDimOp.getSelf() auto integerTy = op.getSelf()
.getType() .getType()
.cast<BaseTensorType>() .template cast<BaseTensorType>()
.getDtype() .getDtype()
.dyn_cast<mlir::IntegerType>(); .template dyn_cast<mlir::IntegerType>();
if (integerTy.isUnsigned()) if (integerTy.isUnsigned())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim to linalg.* requires input element type " op, opName + " to linalg.* requires input element type "
"to be signed in case of integer"); "to be signed in case of integer");
} else { } else {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim to linalg.* requires Float or Integer " op, opName + " to linalg.* requires Float or Integer "
"input element type"); "input element type");
} }
} }
@ -112,29 +122,29 @@ public:
Value filledTensorIdx = Value filledTensorIdx =
createZeroInitTensor(rewriter, loc, resultShape, idxElementType); createZeroInitTensor(rewriter, loc, resultShape, idxElementType);
// Second fill the output buffer for the running max. // Second fill the output buffer for the running max or min.
Value initTensorMax = rewriter.create<tensor::EmptyOp>( Value initTensorVal = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(resultShape), inElementType); loc, getAsOpFoldResult(resultShape), inElementType);
Value fillValueMax; Value fillValue;
if (inElementType.isa<mlir::FloatType>()) { if (inElementType.isa<mlir::FloatType>()) {
fillValueMax = rewriter.create<arith::ConstantOp>( fillValue = rewriter.create<arith::ConstantOp>(
loc, loc,
rewriter.getFloatAttr( rewriter.getFloatAttr(
inElementType, inElementType,
APFloat::getInf( APFloat::getInf(
inElementType.cast<mlir::FloatType>().getFloatSemantics(), inElementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true))); /*Negative=*/isMax)));
} else { } else {
fillValueMax = rewriter.create<arith::ConstantOp>( auto width = inElementType.cast<mlir::IntegerType>().getWidth();
loc, rewriter.getIntegerAttr( auto init = isMax ? APSInt::getSignedMinValue(width)
inElementType, : APSInt::getSignedMaxValue(width);
APSInt::getSignedMinValue( fillValue = rewriter.create<arith::ConstantOp>(
inElementType.cast<mlir::IntegerType>().getWidth()))); loc, rewriter.getIntegerAttr(inElementType, init));
} }
Value filledTensorMax = Value filledTensorVal =
rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax) rewriter.create<linalg::FillOp>(loc, fillValue, initTensorVal)
.result(); .result();
// Create the affine expressions that will be used to // Create the affine expressions that will be used to
@ -161,8 +171,8 @@ public:
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs}); auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs});
auto linalgOp = rewriter.create<linalg::GenericOp>( auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, loc,
ArrayRef<Type>({filledTensorMax.getType(), filledTensorIdx.getType()}), ArrayRef<Type>({filledTensorVal.getType(), filledTensorIdx.getType()}),
input, ValueRange({filledTensorMax, filledTensorIdx}), maps, input, ValueRange({filledTensorVal, filledTensorIdx}), maps,
iteratorTypes, iteratorTypes,
[&](OpBuilder &nestedBuilder, Location nestedLoc, [&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) { ValueRange blockArgs) {
@ -174,33 +184,51 @@ public:
nestedLoc, oldIndex.getType(), nestedLoc, oldIndex.getType(),
rewriter.create<linalg::IndexOp>(loc, dim)); rewriter.create<linalg::IndexOp>(loc, dim));
Value resultMax, predicate; Value resultVal, predicate;
if (inElementType.isa<mlir::FloatType>()) { if (inElementType.isa<mlir::FloatType>()) {
resultMax = rewriter.create<arith::MaximumFOp>(nestedLoc, newValue, arith::CmpFPredicate predType;
oldValue); if constexpr (isMax) {
predicate = rewriter.create<arith::CmpFOp>( predType = arith::CmpFPredicate::OGT;
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); resultVal = rewriter.create<arith::MaximumFOp>(
nestedLoc, newValue, oldValue);
} else {
predType = arith::CmpFPredicate::OLT;
resultVal = rewriter.create<arith::MinimumFOp>(
nestedLoc, newValue, oldValue);
}
predicate = rewriter.create<arith::CmpFOp>(nestedLoc, predType,
newValue, oldValue);
} else { } else {
resultMax = arith::CmpIPredicate predType;
rewriter.create<arith::MaxSIOp>(nestedLoc, newValue, oldValue); if constexpr (isMax) {
predicate = rewriter.create<arith::CmpIOp>( predType = arith::CmpIPredicate::sgt;
nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); resultVal = rewriter.create<arith::MaxSIOp>(nestedLoc, newValue,
oldValue);
} else {
predType = arith::CmpIPredicate::slt;
resultVal = rewriter.create<arith::MinSIOp>(nestedLoc, newValue,
oldValue);
}
predicate = rewriter.create<arith::CmpIOp>(nestedLoc, predType,
newValue, oldValue);
} }
auto resultIndex = rewriter.create<arith::SelectOp>( auto resultIndex = rewriter.create<arith::SelectOp>(
nestedLoc, predicate, newIndex, oldIndex); nestedLoc, predicate, newIndex, oldIndex);
nestedBuilder.create<linalg::YieldOp>( nestedBuilder.create<linalg::YieldOp>(
nestedLoc, ValueRange({resultMax, resultIndex})); nestedLoc, ValueRange({resultVal, resultIndex}));
}); });
// This cast is required to fix the shape in the case of keepDim=True // This cast is required to fix the shape in the case of keepDim=True
Value maxValuesCast = rewriter.create<tensor::CastOp>( Value valuesCast = rewriter.create<tensor::CastOp>(
loc, valResultType, linalgOp.getResult(0)); loc, valResultType, linalgOp.getResult(0));
Value maxIdxCast = rewriter.create<tensor::CastOp>(loc, idxResultType, Value idxCast = rewriter.create<tensor::CastOp>(loc, idxResultType,
linalgOp.getResult(1)); linalgOp.getResult(1));
rewriter.replaceOp(maxDimOp, {maxValuesCast, maxIdxCast}); rewriter.replaceOp(op, {valuesCast, idxCast});
return success(); return success();
} }
}; };
} // namespace } // namespace
static Value createInitElementForReduceOp(OpBuilder &b, Location loc, static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
@ -574,7 +602,9 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
ConversionTarget &target) { ConversionTarget &target) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenMaxDimOp>(); target.addIllegalOp<AtenMaxDimOp>();
patterns.add<ConvertAtenMaxDimOp>(typeConverter, context); patterns.add<ConvertAtenMinMaxDimOp<AtenMaxDimOp>>(typeConverter, context);
target.addIllegalOp<AtenMinDimOp>();
patterns.add<ConvertAtenMinMaxDimOp<AtenMinDimOp>>(typeConverter, context);
target.addIllegalOp<AtenSumOp>(); target.addIllegalOp<AtenSumOp>();
target.addIllegalOp<AtenSumDimIntListOp>(); target.addIllegalOp<AtenSumDimIntListOp>();
target.addIllegalOp<AtenProdDimIntOp>(); target.addIllegalOp<AtenProdDimIntOp>();

View File

@ -6872,6 +6872,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n" " %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>>\n" " return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.min.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.amax\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.amax\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %0 = torch.derefine %arg1 : !torch.list<int> to !torch.optional<list<int>>\n" " %0 = torch.derefine %arg1 : !torch.list<int> to !torch.optional<list<int>>\n"
@ -10691,6 +10697,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n" " %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n" " return %1 : !torch.tuple<int, int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.min.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<int, int> {\n"
" %int4 = torch.constant.int 4\n"
" %0 = call @\"__torch_mlir_dtype_fn.aten.min\"(%arg0) : (!torch.tuple<int, int>) -> !torch.int\n"
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>) -> !torch.int {\n"
" %false = torch.constant.bool false\n" " %false = torch.constant.bool false\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"

View File

@ -18,7 +18,7 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"IscloseStaticModule_basic", "IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic", "IscloseStaticModuleTrue_basic"
} }
TORCHDYNAMO_XFAIL_SET = { TORCHDYNAMO_XFAIL_SET = {
@ -69,6 +69,7 @@ TORCHDYNAMO_XFAIL_SET = {
#ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777) #ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777)
"UpSampleNearest2dDynamicFactor_basic", "UpSampleNearest2dDynamicFactor_basic",
"ReduceMaxAlongDimUnsignedInt_basic", "ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
#ERROR: value (-56) is not equal to golden value (200) #ERROR: value (-56) is not equal to golden value (200)
"AtenIntTensorByteDtypeModule_basic", "AtenIntTensorByteDtypeModule_basic",
# ERROR: assert isinstance(e, FakeTensor) # ERROR: assert isinstance(e, FakeTensor)

View File

@ -458,6 +458,10 @@ def atenmaxdim〡shape(self: List[int], dim: int, keepdim: bool = False) -
reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim) reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
return reduced_shape, reduced_shape return reduced_shape, reduced_shape
def atenmindim〡shape(self: List[int], dim: int, keepdim: bool = False) -> Tuple[List[int], List[int]]:
reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
return reduced_shape, reduced_shape
def atenamax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]: def atenamax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
@ -3286,6 +3290,10 @@ def atenamax〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), k
def atenmaxdim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]: def atenmaxdim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]:
return atenmax〡dtype(self_rank_dtype), torch.int64 return atenmax〡dtype(self_rank_dtype), torch.int64
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
def atenmindim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]:
return atenmin〡dtype(self_rank_dtype), torch.int64
@check_dtype_function( @check_dtype_function(
_check_tensors_with_the_same_dtype( _check_tensors_with_the_same_dtype(
num_of_tensors=1, num_of_tensors=1,

View File

@ -14,6 +14,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"NativeGroupNormBackwardModule_basic", "NativeGroupNormBackwardModule_basic",
"QuantizedMLP_basic", "QuantizedMLP_basic",
"ReduceMaxAlongDimUnsignedInt_basic", "ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseToDtypeI64ToUI8Module_basic",
} }

View File

@ -335,6 +335,117 @@ def ReduceMaxAlongDim_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class ReduceMinAlongDim(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, a):
return torch.ops.aten.min(a, 1)[0]
@register_test_case(module_factory=lambda: ReduceMinAlongDim())
def ReduceMinAlongDim_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
class ReduceMinAlongDimSignedInt(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.min(a, 1)
@register_test_case(module_factory=lambda: ReduceMinAlongDimSignedInt())
def ReduceMinAlongDimSignedInt_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, low=-100, high=100))
# ==============================================================================
class ReduceMinAlongDimUnsignedInt(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.uint8, True),
])
def forward(self, a):
return torch.ops.aten.min(a, 1)
@register_test_case(module_factory=lambda: ReduceMinAlongDimUnsignedInt())
def ReduceMinAlongDimUnsignedInt_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, low=-100, high=100).to(torch.uint8))
# ==============================================================================
class ReduceMinAlongDimNegative(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, a):
return torch.ops.aten.min(a, 1)[0]
@register_test_case(module_factory=lambda: ReduceMinAlongDimNegative())
def ReduceMinAlongDimNegative_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5, low=-10, high=10).to(torch.float64))
# ==============================================================================
class ReduceMinKeepDim(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, a):
return torch.ops.aten.min(a, 1, keepdim=True)[1]
@register_test_case(module_factory=lambda: ReduceMinKeepDim())
def ReduceMinKeepDim_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
# ==============================================================================
class ReduceMinKeepDimReturnBoth(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.min(a, 1, keepdim=True)
@register_test_case(module_factory=lambda: ReduceMinKeepDimReturnBoth())
def ReduceMinKeepDimReturnBoth_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5, low=-10, high=-5))
# ==============================================================================
class ReduceMaxAlongDimSignedInt(torch.nn.Module): class ReduceMaxAlongDimSignedInt(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()