[Torch] support decomposition of aten.aminmax (#3513)

* unify decomposition of `aten.amax` and `aten.amin`
* support `aten.amax` with `dim=()` (see the example below)
pull/3515/head
Yuanqiang Liu 2024-06-29 21:44:05 +08:00 committed by GitHub
parent f9fc741eef
commit 0e71a192d8
9 changed files with 254 additions and 106 deletions
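The two summary bullets above amount to the following eager-mode behavior (a minimal sketch using public PyTorch APIs for orientation, assuming standard eager semantics; this is not code from the patch):

import torch

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# aten.aminmax: optional single `dim`, returns a (min, max) pair.
mn, mx = torch.ops.aten.aminmax(x, dim=1, keepdim=False)
assert mn.shape == (2, 4) and mx.shape == (2, 4)

# aten.amax with dim=(): treated as a reduction over all dimensions,
# which is what the new empty-dim handling decomposes to.
full = torch.ops.aten.amax(x, dim=())
assert full.item() == x.max().item()

The diffs below wire exactly this behavior through the generated op definition, the decomposition pass, the shape/dtype library, and the e2e tests.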


@@ -11463,6 +11463,32 @@ def Torch_AtenAminOp : Torch_Op<"aten.amin", [
}];
}
def Torch_AtenAminmaxOp : Torch_Op<"aten.aminmax", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::aminmax : (Tensor, int?, bool) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalIntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchOptionalTensorType:$min,
AnyTorchOptionalTensorType:$max
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAminmaxOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 2);
}
void AtenAminmaxOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 2);
}
}];
}
def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [
AllowsTypeRefinement,
ReadOnly


@@ -488,14 +488,18 @@ public:
       return rewriter.notifyMatchFailure(
           op, "non-const integer `dim` is not supported");
     }
-    for (auto d : inputDims) {
-      d = toPositiveDim(d, inputTy.getRank());
-      // Drop invalid dims
-      if (isValidDim(d, inputTy.getRank())) {
-        dims.push_back(d);
+    if (inputDims.size() == 0) {
+      dims = llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
+    } else {
+      for (auto d : inputDims) {
+        d = toPositiveDim(d, inputTy.getRank());
+        // Drop invalid dims
+        if (isValidDim(d, inputTy.getRank())) {
+          dims.push_back(d);
+        }
       }
+      llvm::sort(dims.begin(), dims.end());
     }
-    llvm::sort(dims.begin(), dims.end());
     SmallVector<int64_t> reduceResultShape =
         getReduceOutputShape(inputTy.getShape(), dims);


@@ -7371,6 +7371,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.aminmax\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.tuple<list<int>, list<int>>) {\n"
" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %3 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %4 = torch.prim.TupleConstruct %2, %3 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" torch.prim.If.yield %4 : !torch.tuple<list<int>, list<int>>\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
" %3 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" %4 = torch.prim.TupleConstruct %3, %3 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" torch.prim.If.yield %4 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" return %1 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any\n" " %0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any\n"
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n" " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
@ -13568,6 +13584,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n" " %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n" " return %1 : !torch.tuple<int, int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.aminmax\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.tuple<int, int> {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>) -> !torch.int {\n"
" %false = torch.constant.bool false\n" " %false = torch.constant.bool false\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"


@@ -113,6 +113,25 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
.getValues();
}
// Reduction function to calculate min along given `dim`.
static Value createMinAlongDimension(PatternRewriter &rewriter, Location loc,
Operation *op, Value input, Value dim,
bool keepDim) {
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
BaseTensorType valueType = cast<BaseTensorType>(computeReductionType(
rewriter, op, cast<BaseTensorType>(input.getType()), dim, keepDim));
if (!valueType)
return nullptr;
BaseTensorType indexType =
cast<BaseTensorType>(valueType.getWithSizesAndDtype(
!valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
: llvm::ArrayRef(valueType.getSizes()),
IntegerType::get(op->getContext(), 64, IntegerType::Signed)));
return rewriter
.create<AtenMinDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
.getValues();
}
// Helper for creating `aten::sub_tensor_op`.
static Value createTensorSub(PatternRewriter &rewriter, Location loc,
Type tensorType, Value lhs, Value rhs) {
@@ -605,65 +624,6 @@ static Value performLastReduceAndPermute(PatternRewriter &rewriter,
return out;
}
namespace {
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
/// number of dimensions across which the max needs to be computed.
/// Eg:
/// INPUT:
/// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False)
///
/// OUTPUT:
/// input_1 = aten.max.dim(initial_input, 2, keepdim) #1
/// input_2 = aten.max.dim(input_1, 1, keepdim) #2
/// final_output = aten.max.dim(input_2, 0, keepdim) #3
///
/// NOTE: We iterate over, in reverse order, every dimension included in `dim`
/// of the `aten.amax` op and create an `aten.amax.dim` op.
/// Input tensor to the next `aten.amax.dim` op is thus the output of the
/// previous `aten.amax.dim` op.
class DecomposeAtenAmaxOp : public OpRewritePattern<AtenAmaxOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAmaxOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
SmallVector<int64_t, 4> dims;
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims)))
return rewriter.notifyMatchFailure(op,
"non-const dim parameter unsupported");
bool keepDim;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
return rewriter.notifyMatchFailure(
op, "Expected a constant boolean value for keepDim");
Value input = op.getSelf();
auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
if (!inputTy || !inputTy.hasSizes()) {
return rewriter.notifyMatchFailure(op,
"Expected input type having sizes");
}
// For every dimension included in `dim` of the op, iterated over in
// reverse order, we create a call to aten.max.dim.
std::sort(dims.rbegin(), dims.rend());
for (int64_t dimInt : dims) {
int64_t inputRank = inputTy.getSizes().size();
dimInt = toPositiveDim(dimInt, inputRank);
if (!isValidDim(dimInt, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
Value dim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dimInt));
// The input to the next invocation of aten.max.dim is the output of the
// previous aten.max.dim op.
input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim);
}
rewriter.replaceOp(op, input);
return success();
}
};
} // end namespace
namespace {
class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
public:
@@ -1880,52 +1840,69 @@ public:
 } // namespace
 namespace {
-class DecomposeAtenAMinMaxOp : public OpRewritePattern<Torch::AtenAminOp> {
+/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
+/// number of dimensions across which the max needs to be computed.
+/// Eg:
+/// INPUT:
+/// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False)
+///
+/// OUTPUT:
+/// input_1 = aten.max.dim(initial_input, 2, keepdim) #1
+/// input_2 = aten.max.dim(input_1, 1, keepdim) #2
+/// final_output = aten.max.dim(input_2, 0, keepdim) #3
+///
+/// NOTE: We iterate over, in reverse order, every dimension included in `dim`
+/// of the `aten.amax` op and create an `aten.amax.dim` op.
+/// Input tensor to the next `aten.amax.dim` op is thus the output of the
+/// previous `aten.amax.dim` op.
+template <typename OpTy, typename DecompOpTy>
+class DecomposeAtenAminAmaxOp : public OpRewritePattern<OpTy> {
 public:
-  using OpRewritePattern<Torch::AtenAminOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(Torch::AtenAminOp op,
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    llvm::SmallVector<int64_t> dimList;
-    if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) {
-      return rewriter.notifyMatchFailure(op, "dims not foldable constants");
+    Location loc = op.getLoc();
+    bool keepDim;
+    if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
+      return rewriter.notifyMatchFailure(
+          op, "Expected a constant boolean value for keepDim");
+    Value input = op.getSelf();
+    auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
+    if (!inputTy || !inputTy.hasSizes()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Expected input type having sizes");
     }
-    bool keepdim;
-    if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepdim))) {
-      return rewriter.notifyMatchFailure(op, "keepdims not foldable constants");
+    SmallVector<int64_t, 4> dims;
+    if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims)))
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const dim parameter unsupported");
+    if (dims.size() == 0) {
+      dims = llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getSizes().size()));
     }
-    auto loc = op.getLoc();
-    std::sort(dimList.begin(), dimList.end(), std::greater<int64_t>());
-    Value reduction = op.getSelf();
-    auto resultTy = cast<Torch::ValueTensorType>(op.getType());
-    auto reductionTy = cast<Torch::ValueTensorType>(reduction.getType());
-    llvm::SmallVector<int64_t> reductionShape(reductionTy.getSizes());
-    for (auto dim : dimList) {
-      auto dimValue = rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(dim));
-      reductionShape[dim] = 1;
-      if (!keepdim) {
-        for (int i = dim, s = reductionShape.size() - 1; i < s; ++i)
-          reductionShape[i] = reductionShape[i + 1];
-        reductionShape.resize(reductionShape.size() - 1);
+    // For every dimension included in `dim` of the op, iterated over in
+    // reverse order, we create a call to aten.max.dim.
+    std::sort(dims.rbegin(), dims.rend());
+    for (int64_t dimInt : dims) {
+      int64_t inputRank = inputTy.getSizes().size();
+      dimInt = toPositiveDim(dimInt, inputRank);
+      if (!isValidDim(dimInt, inputRank))
+        return rewriter.notifyMatchFailure(op, "dim is statically invalid");
+      Value dim = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(dimInt));
+      // The input to the next invocation of aten.max.dim is the output of the
+      // previous aten.max.dim op.
+      static_assert(std::is_same_v<OpTy, AtenAmaxOp> ||
+                    std::is_same_v<OpTy, AtenAminOp>);
+      if (std::is_same_v<OpTy, AtenAmaxOp>) {
+        input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim);
+      } else if (std::is_same_v<OpTy, AtenAminOp>) {
+        input = createMinAlongDimension(rewriter, loc, op, input, dim, keepDim);
       }
-      reductionTy = rewriter.getType<Torch::ValueTensorType>(
-          reductionShape, resultTy.getOptionalDtype());
-      auto idxTy = rewriter.getType<Torch::ValueTensorType>(
-          reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true));
-      llvm::SmallVector<Type, 2> types{reductionTy, idxTy};
-      reduction = rewriter
-                      .create<Torch::AtenMinDimOp>(loc, types, reduction,
-                                                   dimValue, op.getKeepdim())
-                      .getResult(0);
     }
-    rewriter.replaceOp(op, reduction);
+    rewriter.replaceOp(op, input);
     return success();
   }
 };
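The comment block above explains the iterative lowering; as a sanity check of that reasoning, here is a small eager-mode sketch (the `iterative_amax` helper is purely illustrative, not part of this patch) showing that reducing one dimension at a time, in descending dim order, matches a multi-dim `amax`, including the empty-`dim` case added here:

import torch

def iterative_amax(x, dims, keepdim=False):
    # Mirror the decomposition: one max-along-dim per entry of `dims`,
    # visited in descending order so that, with keepdim=False, earlier
    # reductions do not shift the indices of dims still to be reduced.
    if len(dims) == 0:
        dims = range(x.dim())
    for d in sorted(dims, reverse=True):
        x = torch.max(x, dim=d, keepdim=keepdim).values
    return x

x = torch.randn(3, 4, 5)
assert torch.equal(iterative_amax(x, (0, 2, 1)), torch.amax(x, dim=(0, 2, 1)))
assert torch.equal(iterative_amax(x, ()), torch.amax(x, dim=()))

Descending dim order is what lets the pattern feed each already-reduced tensor straight into the next `aten.max.dim`/`aten.min.dim` without re-normalizing the remaining dims.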
@@ -1987,6 +1964,36 @@ public:
};
} // namespace
// Decompose `AtenAminmaxOp` to `AtenAminOp` + `AtenAmaxOp`
namespace {
class DecomposeAtenAminmaxOp : public OpRewritePattern<AtenAminmaxOp> {
public:
using OpRewritePattern<AtenAminmaxOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAminmaxOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Torch::ListType listType =
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>());
Value dimList;
if (isa<Torch::NoneType>(op.getDim().getType())) {
dimList = rewriter.create<Torch::PrimListConstructOp>(loc, listType,
ArrayRef<Value>{});
} else {
dimList = rewriter.create<Torch::PrimListConstructOp>(
loc, listType, ArrayRef<Value>{op.getDim()});
}
auto amin = rewriter.create<AtenAminOp>(
loc, op.getMin().getType(), op.getSelf(), dimList, op.getKeepdim());
auto amax = rewriter.create<AtenAmaxOp>(
loc, op.getMax().getType(), op.getSelf(), dimList, op.getKeepdim());
rewriter.replaceOp(op, {amin, amax});
return success();
}
};
} // namespace
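The pattern above rewrites `aten.aminmax` as an `aten.amin` plus an `aten.amax` over the same dim list; because those ops take a dim *list*, the optional scalar `dim` is wrapped in a `prim.ListConstruct`, with an empty list for the `dim=None` case. A quick eager-mode check of that equivalence (a sketch, not part of the patch):

import torch

x = torch.randn(2, 5)

# With a dim: aminmax should agree with amin/amax along that dim.
mn, mx = torch.ops.aten.aminmax(x, dim=1, keepdim=True)
assert torch.equal(mn, torch.amin(x, dim=(1,), keepdim=True))
assert torch.equal(mx, torch.amax(x, dim=(1,), keepdim=True))

# Without a dim: both outputs are full reductions.
mn, mx = torch.ops.aten.aminmax(x)
assert torch.equal(mn, x.min()) and torch.equal(mx, x.max())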
// Decompose `aten.bucketize` into the following op sequence:
//
// def aten_bucketize(input, boundaries, out_int32, right):
@@ -8598,7 +8605,6 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAMinMaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
@@ -8631,10 +8637,15 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposePrimsIotaOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAminAmaxOp<AtenAmaxOp, AtenMaxDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAminAmaxOp<AtenAminOp, AtenMinDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenArgMinMaxOp<AtenArgminOp, AtenMinDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAminmaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSquareOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdOp>(patterns);
@@ -8707,7 +8718,6 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectScatterOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAmaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);


@@ -438,6 +438,9 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenLinspaceOp>();
target.addIllegalOp<AtenArgmaxOp>();
target.addIllegalOp<AtenArgminOp>();
target.addIllegalOp<AtenAminmaxOp>();
target.addIllegalOp<AtenAmaxOp>();
target.addIllegalOp<AtenAminOp>();
target.addIllegalOp<AtenSquareOp>();
target.addIllegalOp<AtenVarOp>();
target.addIllegalOp<AtenStdOp>();
@@ -502,7 +505,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenNumpyTOp>();
target.addIllegalOp<AtenSelectScatterOp>();
target.addIllegalOp<AtenVarDimOp>();
target.addIllegalOp<AtenAmaxOp>();
target.addIllegalOp<AtenVarCorrectionOp>();
target.addIllegalOp<AtenStdDimOp>();
target.addIllegalOp<AtenStdCorrectionOp>();


@@ -830,6 +830,9 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
}
STABLEHLO_PASS_SET = {
"ReduceAminmaxSingleDim_basic",
"ReduceAminmaxAllDims_basic",
"ReduceAmaxEmptyDim_basic",
"ReduceMinAlongDimNegative_basic", "ReduceMinAlongDimNegative_basic",
"ReduceMinAlongDim_basic", "ReduceMinAlongDim_basic",
"ArgminModule_with_dim", "ArgminModule_with_dim",


@@ -722,6 +722,13 @@ def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = Fa
def aten〇amin〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
def aten〇aminmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[List[int], List[int]]:
    if dim is None:
        return [], []
    else:
        reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
        return reduced_shape, reduced_shape
def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
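The aminmax shape function added above returns empty (scalar) shapes when `dim` is None and otherwise reuses the `argmax` shape helper for both outputs; in eager terms that corresponds to roughly this (a sketch, assuming standard `torch.aminmax` behavior):

import torch

x = torch.randn(3, 4, 5)

mn, mx = torch.aminmax(x)          # dim=None: both results are 0-d tensors
assert mn.shape == torch.Size([]) == mx.shape

mn, mx = torch.aminmax(x, dim=1)   # same shape as argmax along dim 1
assert mn.shape == mx.shape == torch.argmax(x, dim=1).shape == (3, 5)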
@@ -4524,6 +4531,11 @@ def aten〇amin〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), k
def aten〇min〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]:
    return aten〇min〡dtype(self_rank_dtype), torch.int64
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇aminmax〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[int, int]:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype, self_dtype
@check_dtype_function(
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1,


@@ -841,6 +841,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::min.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
emit("aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
emit("aten::amin : (Tensor, int[], bool) -> (Tensor)")
emit("aten::aminmax : (Tensor, int?, bool) -> (Tensor, Tensor)")
emit(
"aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True
)


@@ -1204,6 +1204,29 @@ def ReduceAmaxMultiDim_basic(module, tu: TestUtils):
# ==============================================================================
class ReduceAmaxEmptyDim(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, a):
        return torch.ops.aten.amax(a, dim=())


@register_test_case(module_factory=lambda: ReduceAmaxEmptyDim())
def ReduceAmaxEmptyDim_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5, high=100))
# ==============================================================================
class ReduceAmaxOutOfOrderDim(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -1273,6 +1296,52 @@ def ReduceAminSingleDim_basic(module, tu: TestUtils):
# ==============================================================================
class ReduceAminmaxSingleDim(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, a):
        return torch.ops.aten.aminmax(a, dim=1)


@register_test_case(module_factory=lambda: ReduceAminmaxSingleDim())
def ReduceAminmaxSingleDim_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5, high=100))
# ==============================================================================
class ReduceAminmaxAllDims(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, a):
        return torch.ops.aten.aminmax(a)


@register_test_case(module_factory=lambda: ReduceAminmaxAllDims())
def ReduceAminmaxAllDims_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5, high=100))
# ==============================================================================
class ReduceMinFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()