mirror of https://github.com/llvm/torch-mlir
[Torch] support decomposition of aten.aminmax (#3513)
* unify decompisition of `aten.amax` and `aten.amin` * support `aten.amax` with `dim=()`pull/3515/head
parent
f9fc741eef
commit
0e71a192d8
|
@ -11463,6 +11463,32 @@ def Torch_AtenAminOp : Torch_Op<"aten.amin", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAminmaxOp : Torch_Op<"aten.aminmax", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::aminmax : (Tensor, int?, bool) -> (Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalIntType:$dim,
|
||||
Torch_BoolType:$keepdim
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$min,
|
||||
AnyTorchOptionalTensorType:$max
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenAminmaxOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 2);
|
||||
}
|
||||
void AtenAminmaxOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 2);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
|
|
|
@ -488,6 +488,9 @@ public:
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "non-const integer `dim` is not supported");
|
||||
}
|
||||
if (inputDims.size() == 0) {
|
||||
dims = llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
|
||||
} else {
|
||||
for (auto d : inputDims) {
|
||||
d = toPositiveDim(d, inputTy.getRank());
|
||||
// Drop invalid dims
|
||||
|
@ -496,6 +499,7 @@ public:
|
|||
}
|
||||
}
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
}
|
||||
SmallVector<int64_t> reduceResultShape =
|
||||
getReduceOutputShape(inputTy.getShape(), dims);
|
||||
|
||||
|
|
|
@ -7371,6 +7371,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
" return %2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.aminmax\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.tuple<list<int>, list<int>>) {\n"
|
||||
" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" %3 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" %4 = torch.prim.TupleConstruct %2, %3 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" torch.prim.If.yield %4 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" } else {\n"
|
||||
" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
|
||||
" %3 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
|
||||
" %4 = torch.prim.TupleConstruct %3, %3 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" torch.prim.If.yield %4 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" return %1 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" %0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
|
@ -13568,6 +13584,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" return %1 : !torch.tuple<int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.aminmax\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.tuple<int, int> {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" return %1 : !torch.tuple<int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>) -> !torch.int {\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %none = torch.constant.none\n"
|
||||
|
|
|
@ -113,6 +113,25 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
|
|||
.getValues();
|
||||
}
|
||||
|
||||
// Reduction function to calculate min along given `dim`.
|
||||
static Value createMinAlongDimension(PatternRewriter &rewriter, Location loc,
|
||||
Operation *op, Value input, Value dim,
|
||||
bool keepDim) {
|
||||
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
|
||||
BaseTensorType valueType = cast<BaseTensorType>(computeReductionType(
|
||||
rewriter, op, cast<BaseTensorType>(input.getType()), dim, keepDim));
|
||||
if (!valueType)
|
||||
return nullptr;
|
||||
BaseTensorType indexType =
|
||||
cast<BaseTensorType>(valueType.getWithSizesAndDtype(
|
||||
!valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
|
||||
: llvm::ArrayRef(valueType.getSizes()),
|
||||
IntegerType::get(op->getContext(), 64, IntegerType::Signed)));
|
||||
return rewriter
|
||||
.create<AtenMinDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
|
||||
.getValues();
|
||||
}
|
||||
|
||||
// Helper for creating `aten::sub_tensor_op`.
|
||||
static Value createTensorSub(PatternRewriter &rewriter, Location loc,
|
||||
Type tensorType, Value lhs, Value rhs) {
|
||||
|
@ -605,65 +624,6 @@ static Value performLastReduceAndPermute(PatternRewriter &rewriter,
|
|||
return out;
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
|
||||
/// number of dimensions across which the max needs to be computed.
|
||||
/// Eg:
|
||||
/// INPUT:
|
||||
/// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False)
|
||||
///
|
||||
/// OUTPUT:
|
||||
/// input_1 = aten.max.dim(initial_input, 2, keepdim) #1
|
||||
/// input_2 = aten.max.dim(input_1, 1, keepdim) #2
|
||||
/// final_output = aten.max.dim(input_2, 0, keepdim) #3
|
||||
///
|
||||
/// NOTE: We iterate over, in reverse order, every dimension included in `dim`
|
||||
/// of the `aten.amax` op and create an `aten.amax.dim` op.
|
||||
/// Input tensor to the next `aten.amax.dim` op is thus the output of the
|
||||
/// previous `aten.amax.dim` op.
|
||||
class DecomposeAtenAmaxOp : public OpRewritePattern<AtenAmaxOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenAmaxOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
SmallVector<int64_t, 4> dims;
|
||||
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims)))
|
||||
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const dim parameter unsupported");
|
||||
|
||||
bool keepDim;
|
||||
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected a constant boolean value for keepDim");
|
||||
|
||||
Value input = op.getSelf();
|
||||
auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
|
||||
if (!inputTy || !inputTy.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Expected input type having sizes");
|
||||
}
|
||||
// For every dimension included in `dim` of the op, iterated over in
|
||||
// reverse order, we create a call to aten.max.dim.
|
||||
std::sort(dims.rbegin(), dims.rend());
|
||||
for (int64_t dimInt : dims) {
|
||||
int64_t inputRank = inputTy.getSizes().size();
|
||||
dimInt = toPositiveDim(dimInt, inputRank);
|
||||
if (!isValidDim(dimInt, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(dimInt));
|
||||
// The input to the next invocation of aten.max.dim is the output of the
|
||||
// previous aten.max.dim op.
|
||||
input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim);
|
||||
}
|
||||
rewriter.replaceOp(op, input);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // end namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
|
||||
public:
|
||||
|
@ -1880,52 +1840,69 @@ public:
|
|||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenAMinMaxOp : public OpRewritePattern<Torch::AtenAminOp> {
|
||||
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
|
||||
/// number of dimensions across which the max needs to be computed.
|
||||
/// Eg:
|
||||
/// INPUT:
|
||||
/// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False)
|
||||
///
|
||||
/// OUTPUT:
|
||||
/// input_1 = aten.max.dim(initial_input, 2, keepdim) #1
|
||||
/// input_2 = aten.max.dim(input_1, 1, keepdim) #2
|
||||
/// final_output = aten.max.dim(input_2, 0, keepdim) #3
|
||||
///
|
||||
/// NOTE: We iterate over, in reverse order, every dimension included in `dim`
|
||||
/// of the `aten.amax` op and create an `aten.amax.dim` op.
|
||||
/// Input tensor to the next `aten.amax.dim` op is thus the output of the
|
||||
/// previous `aten.amax.dim` op.
|
||||
template <typename OpTy, typename DecompOpTy>
|
||||
class DecomposeAtenAminAmaxOp : public OpRewritePattern<OpTy> {
|
||||
public:
|
||||
using OpRewritePattern<Torch::AtenAminOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(Torch::AtenAminOp op,
|
||||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
llvm::SmallVector<int64_t> dimList;
|
||||
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) {
|
||||
return rewriter.notifyMatchFailure(op, "dims not foldable constants");
|
||||
Location loc = op.getLoc();
|
||||
bool keepDim;
|
||||
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected a constant boolean value for keepDim");
|
||||
|
||||
Value input = op.getSelf();
|
||||
auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
|
||||
if (!inputTy || !inputTy.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Expected input type having sizes");
|
||||
}
|
||||
|
||||
bool keepdim;
|
||||
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepdim))) {
|
||||
return rewriter.notifyMatchFailure(op, "keepdims not foldable constants");
|
||||
SmallVector<int64_t, 4> dims;
|
||||
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const dim parameter unsupported");
|
||||
if (dims.size() == 0) {
|
||||
dims = llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getSizes().size()));
|
||||
}
|
||||
|
||||
auto loc = op.getLoc();
|
||||
std::sort(dimList.begin(), dimList.end(), std::greater<int64_t>());
|
||||
|
||||
Value reduction = op.getSelf();
|
||||
auto resultTy = cast<Torch::ValueTensorType>(op.getType());
|
||||
auto reductionTy = cast<Torch::ValueTensorType>(reduction.getType());
|
||||
llvm::SmallVector<int64_t> reductionShape(reductionTy.getSizes());
|
||||
|
||||
for (auto dim : dimList) {
|
||||
auto dimValue = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(dim));
|
||||
reductionShape[dim] = 1;
|
||||
if (!keepdim) {
|
||||
for (int i = dim, s = reductionShape.size() - 1; i < s; ++i)
|
||||
reductionShape[i] = reductionShape[i + 1];
|
||||
reductionShape.resize(reductionShape.size() - 1);
|
||||
// For every dimension included in `dim` of the op, iterated over in
|
||||
// reverse order, we create a call to aten.max.dim.
|
||||
std::sort(dims.rbegin(), dims.rend());
|
||||
for (int64_t dimInt : dims) {
|
||||
int64_t inputRank = inputTy.getSizes().size();
|
||||
dimInt = toPositiveDim(dimInt, inputRank);
|
||||
if (!isValidDim(dimInt, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(dimInt));
|
||||
// The input to the next invocation of aten.max.dim is the output of the
|
||||
// previous aten.max.dim op.
|
||||
static_assert(std::is_same_v<OpTy, AtenAmaxOp> ||
|
||||
std::is_same_v<OpTy, AtenAminOp>);
|
||||
if (std::is_same_v<OpTy, AtenAmaxOp>) {
|
||||
input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim);
|
||||
} else if (std::is_same_v<OpTy, AtenAminOp>) {
|
||||
input = createMinAlongDimension(rewriter, loc, op, input, dim, keepDim);
|
||||
}
|
||||
|
||||
reductionTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
reductionShape, resultTy.getOptionalDtype());
|
||||
auto idxTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true));
|
||||
llvm::SmallVector<Type, 2> types{reductionTy, idxTy};
|
||||
|
||||
reduction = rewriter
|
||||
.create<Torch::AtenMinDimOp>(loc, types, reduction,
|
||||
dimValue, op.getKeepdim())
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, reduction);
|
||||
rewriter.replaceOp(op, input);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -1987,6 +1964,36 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose `AtenAminmaxOp` to `AtenAminOp` + `AtenAmaxOp`
|
||||
namespace {
|
||||
class DecomposeAtenAminmaxOp : public OpRewritePattern<AtenAminmaxOp> {
|
||||
public:
|
||||
using OpRewritePattern<AtenAminmaxOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenAminmaxOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
|
||||
Torch::ListType listType =
|
||||
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>());
|
||||
Value dimList;
|
||||
if (isa<Torch::NoneType>(op.getDim().getType())) {
|
||||
dimList = rewriter.create<Torch::PrimListConstructOp>(loc, listType,
|
||||
ArrayRef<Value>{});
|
||||
} else {
|
||||
dimList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, listType, ArrayRef<Value>{op.getDim()});
|
||||
}
|
||||
|
||||
auto amin = rewriter.create<AtenAminOp>(
|
||||
loc, op.getMin().getType(), op.getSelf(), dimList, op.getKeepdim());
|
||||
auto amax = rewriter.create<AtenAmaxOp>(
|
||||
loc, op.getMax().getType(), op.getSelf(), dimList, op.getKeepdim());
|
||||
rewriter.replaceOp(op, {amin, amax});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose `aten.bucketize` into the following op sequence:
|
||||
//
|
||||
// def aten_bucketize(input, boundaries, out_int32, right):
|
||||
|
@ -8598,7 +8605,6 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAMinMaxOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
|
||||
|
@ -8631,10 +8637,15 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposePrimsIotaOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenAminAmaxOp<AtenAmaxOp, AtenMaxDimOp>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenAminAmaxOp<AtenAminOp, AtenMinDimOp>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenArgMinMaxOp<AtenArgminOp, AtenMinDimOp>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAminmaxOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSquareOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStdOp>(patterns);
|
||||
|
@ -8707,7 +8718,6 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectScatterOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAmaxOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
|
||||
|
|
|
@ -438,6 +438,9 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenLinspaceOp>();
|
||||
target.addIllegalOp<AtenArgmaxOp>();
|
||||
target.addIllegalOp<AtenArgminOp>();
|
||||
target.addIllegalOp<AtenAminmaxOp>();
|
||||
target.addIllegalOp<AtenAmaxOp>();
|
||||
target.addIllegalOp<AtenAminOp>();
|
||||
target.addIllegalOp<AtenSquareOp>();
|
||||
target.addIllegalOp<AtenVarOp>();
|
||||
target.addIllegalOp<AtenStdOp>();
|
||||
|
@ -502,7 +505,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenNumpyTOp>();
|
||||
target.addIllegalOp<AtenSelectScatterOp>();
|
||||
target.addIllegalOp<AtenVarDimOp>();
|
||||
target.addIllegalOp<AtenAmaxOp>();
|
||||
target.addIllegalOp<AtenVarCorrectionOp>();
|
||||
target.addIllegalOp<AtenStdDimOp>();
|
||||
target.addIllegalOp<AtenStdCorrectionOp>();
|
||||
|
|
|
@ -830,6 +830,9 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
|
|||
}
|
||||
|
||||
STABLEHLO_PASS_SET = {
|
||||
"ReduceAminmaxSingleDim_basic",
|
||||
"ReduceAminmaxAllDims_basic",
|
||||
"ReduceAmaxEmptyDim_basic",
|
||||
"ReduceMinAlongDimNegative_basic",
|
||||
"ReduceMinAlongDim_basic",
|
||||
"ArgminModule_with_dim",
|
||||
|
|
|
@ -722,6 +722,13 @@ def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = Fa
|
|||
def aten〇amin〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
||||
|
||||
def aten〇aminmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[List[int], List[int]]:
|
||||
if dim is None:
|
||||
return [], []
|
||||
else:
|
||||
reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
|
||||
return reduced_shape, reduced_shape
|
||||
|
||||
def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
|
||||
|
||||
|
@ -4524,6 +4531,11 @@ def aten〇amin〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), k
|
|||
def aten〇min〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]:
|
||||
return aten〇min〡dtype(self_rank_dtype), torch.int64
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇aminmax〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[int, int]:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype, self_dtype
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(
|
||||
num_of_tensors=1,
|
||||
|
|
|
@ -841,6 +841,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::min.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
|
||||
emit("aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
|
||||
emit("aten::amin : (Tensor, int[], bool) -> (Tensor)")
|
||||
emit("aten::aminmax : (Tensor, int?, bool) -> (Tensor, Tensor)")
|
||||
emit(
|
||||
"aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True
|
||||
)
|
||||
|
|
|
@ -1204,6 +1204,29 @@ def ReduceAmaxMultiDim_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceAmaxEmptyDim(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.amax(a, dim=())
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceAmaxEmptyDim())
|
||||
def ReduceAmaxEmptyDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceAmaxOutOfOrderDim(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -1273,6 +1296,52 @@ def ReduceAminSingleDim_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceAminmaxSingleDim(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.aminmax(a, dim=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceAminmaxSingleDim())
|
||||
def ReduceAminmaxSingleDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceAminmaxAllDims(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.aminmax(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceAminmaxAllDims())
|
||||
def ReduceAminmaxAllDims_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceMinFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue