mirror of https://github.com/llvm/torch-mlir
Add aten.min.dim to linalg lowering (#2600)
parent
d0b49a912e
commit
6248216dca
|
@ -30,70 +30,80 @@ using namespace mlir::torch;
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Aten maxdim lowering represents the MaxDim op as an linalg.indexed_generic
|
// Aten max.dim (min.dim) lowering represents the MaxDimOp (MinDimOp) as an
|
||||||
// op, producing two output buffers.
|
// linalg.indexed_generic op, producing two output buffers.
|
||||||
//
|
//
|
||||||
// The first output buffer contains the maximum value found. It is initialized
|
// The first output buffer contains the maximum (minium) value found. It is
|
||||||
// to the minimum representable value of the input element type.
|
// initialized to the minimum (maximum) representable value of the input
|
||||||
|
// element type.
|
||||||
//
|
//
|
||||||
// The second output buffer contains the index of the found maximum value. It is
|
// The second output buffer contains the index of the found maximum (minimum)
|
||||||
// initialized to 0 and is resulting integer type.
|
// value. It is initialized to 0 and is resulting integer type.
|
||||||
//
|
//
|
||||||
// The indexed_generic op updates both the maximum value and index if the
|
// The indexed_generic op updates both the maximum (minimum) value and index
|
||||||
// current value exceeds the running max.
|
// if the current value exceeds the running max (min).
|
||||||
class ConvertAtenMaxDimOp : public OpConversionPattern<AtenMaxDimOp> {
|
template <typename OpTy>
|
||||||
|
class ConvertAtenMinMaxDimOp : public OpConversionPattern<OpTy> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern<AtenMaxDimOp>::OpConversionPattern;
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||||
|
using OpConversionPattern<OpTy>::getTypeConverter;
|
||||||
|
|
||||||
|
using OpAdaptor = typename OpTy::Adaptor;
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(AtenMaxDimOp maxDimOp, OpAdaptor adaptor,
|
matchAndRewrite(OpTy op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
static_assert(std::is_same<OpTy, AtenMaxDimOp>() ||
|
||||||
|
std::is_same<OpTy, AtenMinDimOp>());
|
||||||
|
constexpr bool isMax = std::is_same<OpTy, AtenMaxDimOp>();
|
||||||
|
const llvm::StringRef opName = op->getName().getStringRef();
|
||||||
|
|
||||||
Location loc = maxDimOp.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
RankedTensorType valResultType =
|
RankedTensorType valResultType =
|
||||||
getTypeConverter()
|
getTypeConverter()
|
||||||
->convertType(maxDimOp.getResult(0).getType())
|
->convertType(op.getResult(0).getType())
|
||||||
.cast<RankedTensorType>();
|
.template cast<RankedTensorType>();
|
||||||
|
|
||||||
RankedTensorType idxResultType =
|
RankedTensorType idxResultType =
|
||||||
getTypeConverter()
|
this->getTypeConverter()
|
||||||
->convertType(maxDimOp.getResult(1).getType())
|
->convertType(op.getResult(1).getType())
|
||||||
.cast<RankedTensorType>();
|
.template cast<RankedTensorType>();
|
||||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
RankedTensorType inputType =
|
||||||
|
input.getType().template cast<RankedTensorType>();
|
||||||
Type idxElementType = idxResultType.getElementType();
|
Type idxElementType = idxResultType.getElementType();
|
||||||
if (!idxElementType.isa<IntegerType>())
|
if (!idxElementType.isa<IntegerType>())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
maxDimOp,
|
op, opName + " to linalg.* requires integer-like result type");
|
||||||
"aten.max_dim to linalg.* requires integer-like result type");
|
|
||||||
|
|
||||||
bool keepDim = false;
|
bool keepDim = false;
|
||||||
if (!matchPattern(maxDimOp.getKeepdim(), m_TorchConstantBool(&keepDim)))
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
maxDimOp, "aten.max_dim requires boolean value for keepdim");
|
op, opName + " requires boolean value for keepdim");
|
||||||
|
|
||||||
int64_t dim;
|
int64_t dim;
|
||||||
if (!matchPattern(maxDimOp.getDim(), m_TorchConstantInt(&dim)))
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
maxDimOp, "aten.max_dim to linalg.* requires int value for Dim");
|
op, opName + " to linalg.* requires int value for Dim");
|
||||||
dim = toPositiveDim(dim, inputType.getRank());
|
dim = toPositiveDim(dim, inputType.getRank());
|
||||||
if (!isValidDim(dim, inputType.getRank()))
|
if (!isValidDim(dim, inputType.getRank()))
|
||||||
return rewriter.notifyMatchFailure(maxDimOp, "dim is not a valid dim");
|
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
||||||
|
|
||||||
Type inElementType = inputType.getElementType();
|
Type inElementType = inputType.getElementType();
|
||||||
if (!inElementType.isa<mlir::FloatType>()) {
|
if (!inElementType.isa<mlir::FloatType>()) {
|
||||||
if (inElementType.isa<mlir::IntegerType>()) {
|
if (inElementType.isa<mlir::IntegerType>()) {
|
||||||
auto integerTy = maxDimOp.getSelf()
|
auto integerTy = op.getSelf()
|
||||||
.getType()
|
.getType()
|
||||||
.cast<BaseTensorType>()
|
.template cast<BaseTensorType>()
|
||||||
.getDtype()
|
.getDtype()
|
||||||
.dyn_cast<mlir::IntegerType>();
|
.template dyn_cast<mlir::IntegerType>();
|
||||||
if (integerTy.isUnsigned())
|
if (integerTy.isUnsigned())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
maxDimOp, "aten.max_dim to linalg.* requires input element type "
|
op, opName + " to linalg.* requires input element type "
|
||||||
"to be signed in case of integer");
|
"to be signed in case of integer");
|
||||||
} else {
|
} else {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
maxDimOp, "aten.max_dim to linalg.* requires Float or Integer "
|
op, opName + " to linalg.* requires Float or Integer "
|
||||||
"input element type");
|
"input element type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -112,29 +122,29 @@ public:
|
||||||
Value filledTensorIdx =
|
Value filledTensorIdx =
|
||||||
createZeroInitTensor(rewriter, loc, resultShape, idxElementType);
|
createZeroInitTensor(rewriter, loc, resultShape, idxElementType);
|
||||||
|
|
||||||
// Second fill the output buffer for the running max.
|
// Second fill the output buffer for the running max or min.
|
||||||
Value initTensorMax = rewriter.create<tensor::EmptyOp>(
|
Value initTensorVal = rewriter.create<tensor::EmptyOp>(
|
||||||
loc, getAsOpFoldResult(resultShape), inElementType);
|
loc, getAsOpFoldResult(resultShape), inElementType);
|
||||||
|
|
||||||
Value fillValueMax;
|
Value fillValue;
|
||||||
if (inElementType.isa<mlir::FloatType>()) {
|
if (inElementType.isa<mlir::FloatType>()) {
|
||||||
fillValueMax = rewriter.create<arith::ConstantOp>(
|
fillValue = rewriter.create<arith::ConstantOp>(
|
||||||
loc,
|
loc,
|
||||||
rewriter.getFloatAttr(
|
rewriter.getFloatAttr(
|
||||||
inElementType,
|
inElementType,
|
||||||
APFloat::getInf(
|
APFloat::getInf(
|
||||||
inElementType.cast<mlir::FloatType>().getFloatSemantics(),
|
inElementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||||
/*Negative=*/true)));
|
/*Negative=*/isMax)));
|
||||||
} else {
|
} else {
|
||||||
fillValueMax = rewriter.create<arith::ConstantOp>(
|
auto width = inElementType.cast<mlir::IntegerType>().getWidth();
|
||||||
loc, rewriter.getIntegerAttr(
|
auto init = isMax ? APSInt::getSignedMinValue(width)
|
||||||
inElementType,
|
: APSInt::getSignedMaxValue(width);
|
||||||
APSInt::getSignedMinValue(
|
fillValue = rewriter.create<arith::ConstantOp>(
|
||||||
inElementType.cast<mlir::IntegerType>().getWidth())));
|
loc, rewriter.getIntegerAttr(inElementType, init));
|
||||||
}
|
}
|
||||||
|
|
||||||
Value filledTensorMax =
|
Value filledTensorVal =
|
||||||
rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
|
rewriter.create<linalg::FillOp>(loc, fillValue, initTensorVal)
|
||||||
.result();
|
.result();
|
||||||
|
|
||||||
// Create the affine expressions that will be used to
|
// Create the affine expressions that will be used to
|
||||||
|
@ -161,8 +171,8 @@ public:
|
||||||
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs});
|
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs});
|
||||||
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
||||||
loc,
|
loc,
|
||||||
ArrayRef<Type>({filledTensorMax.getType(), filledTensorIdx.getType()}),
|
ArrayRef<Type>({filledTensorVal.getType(), filledTensorIdx.getType()}),
|
||||||
input, ValueRange({filledTensorMax, filledTensorIdx}), maps,
|
input, ValueRange({filledTensorVal, filledTensorIdx}), maps,
|
||||||
iteratorTypes,
|
iteratorTypes,
|
||||||
[&](OpBuilder &nestedBuilder, Location nestedLoc,
|
[&](OpBuilder &nestedBuilder, Location nestedLoc,
|
||||||
ValueRange blockArgs) {
|
ValueRange blockArgs) {
|
||||||
|
@ -174,33 +184,51 @@ public:
|
||||||
nestedLoc, oldIndex.getType(),
|
nestedLoc, oldIndex.getType(),
|
||||||
rewriter.create<linalg::IndexOp>(loc, dim));
|
rewriter.create<linalg::IndexOp>(loc, dim));
|
||||||
|
|
||||||
Value resultMax, predicate;
|
Value resultVal, predicate;
|
||||||
if (inElementType.isa<mlir::FloatType>()) {
|
if (inElementType.isa<mlir::FloatType>()) {
|
||||||
resultMax = rewriter.create<arith::MaximumFOp>(nestedLoc, newValue,
|
arith::CmpFPredicate predType;
|
||||||
oldValue);
|
if constexpr (isMax) {
|
||||||
predicate = rewriter.create<arith::CmpFOp>(
|
predType = arith::CmpFPredicate::OGT;
|
||||||
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
|
resultVal = rewriter.create<arith::MaximumFOp>(
|
||||||
|
nestedLoc, newValue, oldValue);
|
||||||
} else {
|
} else {
|
||||||
resultMax =
|
predType = arith::CmpFPredicate::OLT;
|
||||||
rewriter.create<arith::MaxSIOp>(nestedLoc, newValue, oldValue);
|
resultVal = rewriter.create<arith::MinimumFOp>(
|
||||||
predicate = rewriter.create<arith::CmpIOp>(
|
nestedLoc, newValue, oldValue);
|
||||||
nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
|
}
|
||||||
|
|
||||||
|
predicate = rewriter.create<arith::CmpFOp>(nestedLoc, predType,
|
||||||
|
newValue, oldValue);
|
||||||
|
} else {
|
||||||
|
arith::CmpIPredicate predType;
|
||||||
|
if constexpr (isMax) {
|
||||||
|
predType = arith::CmpIPredicate::sgt;
|
||||||
|
resultVal = rewriter.create<arith::MaxSIOp>(nestedLoc, newValue,
|
||||||
|
oldValue);
|
||||||
|
} else {
|
||||||
|
predType = arith::CmpIPredicate::slt;
|
||||||
|
resultVal = rewriter.create<arith::MinSIOp>(nestedLoc, newValue,
|
||||||
|
oldValue);
|
||||||
|
}
|
||||||
|
predicate = rewriter.create<arith::CmpIOp>(nestedLoc, predType,
|
||||||
|
newValue, oldValue);
|
||||||
}
|
}
|
||||||
auto resultIndex = rewriter.create<arith::SelectOp>(
|
auto resultIndex = rewriter.create<arith::SelectOp>(
|
||||||
nestedLoc, predicate, newIndex, oldIndex);
|
nestedLoc, predicate, newIndex, oldIndex);
|
||||||
nestedBuilder.create<linalg::YieldOp>(
|
nestedBuilder.create<linalg::YieldOp>(
|
||||||
nestedLoc, ValueRange({resultMax, resultIndex}));
|
nestedLoc, ValueRange({resultVal, resultIndex}));
|
||||||
});
|
});
|
||||||
|
|
||||||
// This cast is required to fix the shape in the case of keepDim=True
|
// This cast is required to fix the shape in the case of keepDim=True
|
||||||
Value maxValuesCast = rewriter.create<tensor::CastOp>(
|
Value valuesCast = rewriter.create<tensor::CastOp>(
|
||||||
loc, valResultType, linalgOp.getResult(0));
|
loc, valResultType, linalgOp.getResult(0));
|
||||||
Value maxIdxCast = rewriter.create<tensor::CastOp>(loc, idxResultType,
|
Value idxCast = rewriter.create<tensor::CastOp>(loc, idxResultType,
|
||||||
linalgOp.getResult(1));
|
linalgOp.getResult(1));
|
||||||
rewriter.replaceOp(maxDimOp, {maxValuesCast, maxIdxCast});
|
rewriter.replaceOp(op, {valuesCast, idxCast});
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
|
static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
|
||||||
|
@ -574,7 +602,9 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
|
||||||
ConversionTarget &target) {
|
ConversionTarget &target) {
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
target.addIllegalOp<AtenMaxDimOp>();
|
target.addIllegalOp<AtenMaxDimOp>();
|
||||||
patterns.add<ConvertAtenMaxDimOp>(typeConverter, context);
|
patterns.add<ConvertAtenMinMaxDimOp<AtenMaxDimOp>>(typeConverter, context);
|
||||||
|
target.addIllegalOp<AtenMinDimOp>();
|
||||||
|
patterns.add<ConvertAtenMinMaxDimOp<AtenMinDimOp>>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenSumOp>();
|
target.addIllegalOp<AtenSumOp>();
|
||||||
target.addIllegalOp<AtenSumDimIntListOp>();
|
target.addIllegalOp<AtenSumDimIntListOp>();
|
||||||
target.addIllegalOp<AtenProdDimIntOp>();
|
target.addIllegalOp<AtenProdDimIntOp>();
|
||||||
|
|
|
@ -6872,6 +6872,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||||
" return %2 : !torch.tuple<list<int>, list<int>>\n"
|
" return %2 : !torch.tuple<list<int>, list<int>>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.min.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||||
|
" %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
|
||||||
|
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
|
||||||
|
" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||||
|
" return %2 : !torch.tuple<list<int>, list<int>>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.amax\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.amax\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
" %0 = torch.derefine %arg1 : !torch.list<int> to !torch.optional<list<int>>\n"
|
" %0 = torch.derefine %arg1 : !torch.list<int> to !torch.optional<list<int>>\n"
|
||||||
|
@ -10691,6 +10697,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||||
" return %1 : !torch.tuple<int, int>\n"
|
" return %1 : !torch.tuple<int, int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.min.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<int, int> {\n"
|
||||||
|
" %int4 = torch.constant.int 4\n"
|
||||||
|
" %0 = call @\"__torch_mlir_dtype_fn.aten.min\"(%arg0) : (!torch.tuple<int, int>) -> !torch.int\n"
|
||||||
|
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||||
|
" return %1 : !torch.tuple<int, int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>) -> !torch.int {\n"
|
||||||
" %false = torch.constant.bool false\n"
|
" %false = torch.constant.bool false\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
|
|
|
@ -18,7 +18,7 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
||||||
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
|
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
|
||||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||||
"IscloseStaticModule_basic",
|
"IscloseStaticModule_basic",
|
||||||
"IscloseStaticModuleTrue_basic",
|
"IscloseStaticModuleTrue_basic"
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCHDYNAMO_XFAIL_SET = {
|
TORCHDYNAMO_XFAIL_SET = {
|
||||||
|
@ -69,6 +69,7 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
#ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777)
|
#ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777)
|
||||||
"UpSampleNearest2dDynamicFactor_basic",
|
"UpSampleNearest2dDynamicFactor_basic",
|
||||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||||
|
"ReduceMinAlongDimUnsignedInt_basic",
|
||||||
#ERROR: value (-56) is not equal to golden value (200)
|
#ERROR: value (-56) is not equal to golden value (200)
|
||||||
"AtenIntTensorByteDtypeModule_basic",
|
"AtenIntTensorByteDtypeModule_basic",
|
||||||
# ERROR: assert isinstance(e, FakeTensor)
|
# ERROR: assert isinstance(e, FakeTensor)
|
||||||
|
|
|
@ -458,6 +458,10 @@ def aten〇max〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -
|
||||||
reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
|
reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
|
||||||
return reduced_shape, reduced_shape
|
return reduced_shape, reduced_shape
|
||||||
|
|
||||||
|
def aten〇min〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> Tuple[List[int], List[int]]:
|
||||||
|
reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
|
||||||
|
return reduced_shape, reduced_shape
|
||||||
|
|
||||||
def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
|
def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
|
||||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
||||||
|
|
||||||
|
@ -3286,6 +3290,10 @@ def aten〇amax〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), k
|
||||||
def aten〇max〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]:
|
def aten〇max〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]:
|
||||||
return aten〇max〡dtype(self_rank_dtype), torch.int64
|
return aten〇max〡dtype(self_rank_dtype), torch.int64
|
||||||
|
|
||||||
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
|
||||||
|
def aten〇min〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]:
|
||||||
|
return aten〇min〡dtype(self_rank_dtype), torch.int64
|
||||||
|
|
||||||
@check_dtype_function(
|
@check_dtype_function(
|
||||||
_check_tensors_with_the_same_dtype(
|
_check_tensors_with_the_same_dtype(
|
||||||
num_of_tensors=1,
|
num_of_tensors=1,
|
||||||
|
|
|
@ -14,6 +14,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
||||||
"NativeGroupNormBackwardModule_basic",
|
"NativeGroupNormBackwardModule_basic",
|
||||||
"QuantizedMLP_basic",
|
"QuantizedMLP_basic",
|
||||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||||
|
"ReduceMinAlongDimUnsignedInt_basic",
|
||||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -335,6 +335,117 @@ def ReduceMaxAlongDim_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ReduceMinAlongDim(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.min(a, 1)[0]
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReduceMinAlongDim())
|
||||||
|
def ReduceMinAlongDim_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||||
|
|
||||||
|
class ReduceMinAlongDimSignedInt(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.min(a, 1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReduceMinAlongDimSignedInt())
|
||||||
|
def ReduceMinAlongDimSignedInt_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, 5, low=-100, high=100))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ReduceMinAlongDimUnsignedInt(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.uint8, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.min(a, 1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReduceMinAlongDimUnsignedInt())
|
||||||
|
def ReduceMinAlongDimUnsignedInt_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, 5, low=-100, high=100).to(torch.uint8))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ReduceMinAlongDimNegative(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.min(a, 1)[0]
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReduceMinAlongDimNegative())
|
||||||
|
def ReduceMinAlongDimNegative_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5, low=-10, high=10).to(torch.float64))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ReduceMinKeepDim(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.min(a, 1, keepdim=True)[1]
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReduceMinKeepDim())
|
||||||
|
def ReduceMinKeepDim_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ReduceMinKeepDimReturnBoth(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.min(a, 1, keepdim=True)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReduceMinKeepDimReturnBoth())
|
||||||
|
def ReduceMinKeepDimReturnBoth_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5, low=-10, high=-5))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ReduceMaxAlongDimSignedInt(torch.nn.Module):
|
class ReduceMaxAlongDimSignedInt(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue