mirror of https://github.com/llvm/torch-mlir
[Torch] support decomposition of aten.aminmax (#3513)
* unify decompisition of `aten.amax` and `aten.amin` * support `aten.amax` with `dim=()`pull/3515/head
parent
f9fc741eef
commit
0e71a192d8
|
@ -11463,6 +11463,32 @@ def Torch_AtenAminOp : Torch_Op<"aten.amin", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenAminmaxOp : Torch_Op<"aten.aminmax", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::aminmax : (Tensor, int?, bool) -> (Tensor, Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchOptionalIntType:$dim,
|
||||||
|
Torch_BoolType:$keepdim
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$min,
|
||||||
|
AnyTorchOptionalTensorType:$max
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenAminmaxOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 3, 2);
|
||||||
|
}
|
||||||
|
void AtenAminmaxOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 3, 2);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [
|
def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
ReadOnly
|
ReadOnly
|
||||||
|
|
|
@ -488,14 +488,18 @@ public:
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "non-const integer `dim` is not supported");
|
op, "non-const integer `dim` is not supported");
|
||||||
}
|
}
|
||||||
for (auto d : inputDims) {
|
if (inputDims.size() == 0) {
|
||||||
d = toPositiveDim(d, inputTy.getRank());
|
dims = llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
|
||||||
// Drop invalid dims
|
} else {
|
||||||
if (isValidDim(d, inputTy.getRank())) {
|
for (auto d : inputDims) {
|
||||||
dims.push_back(d);
|
d = toPositiveDim(d, inputTy.getRank());
|
||||||
|
// Drop invalid dims
|
||||||
|
if (isValidDim(d, inputTy.getRank())) {
|
||||||
|
dims.push_back(d);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
llvm::sort(dims.begin(), dims.end());
|
||||||
}
|
}
|
||||||
llvm::sort(dims.begin(), dims.end());
|
|
||||||
SmallVector<int64_t> reduceResultShape =
|
SmallVector<int64_t> reduceResultShape =
|
||||||
getReduceOutputShape(inputTy.getShape(), dims);
|
getReduceOutputShape(inputTy.getShape(), dims);
|
||||||
|
|
||||||
|
|
|
@ -7371,6 +7371,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||||
" return %2 : !torch.list<int>\n"
|
" return %2 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.aminmax\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||||
|
" %1 = torch.prim.If %0 -> (!torch.tuple<list<int>, list<int>>) {\n"
|
||||||
|
" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||||
|
" %3 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||||
|
" %4 = torch.prim.TupleConstruct %2, %3 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||||
|
" torch.prim.If.yield %4 : !torch.tuple<list<int>, list<int>>\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
|
||||||
|
" %3 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
|
||||||
|
" %4 = torch.prim.TupleConstruct %3, %3 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||||
|
" torch.prim.If.yield %4 : !torch.tuple<list<int>, list<int>>\n"
|
||||||
|
" }\n"
|
||||||
|
" return %1 : !torch.tuple<list<int>, list<int>>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any\n"
|
" %0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any\n"
|
||||||
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||||
|
@ -13568,6 +13584,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||||
" return %1 : !torch.tuple<int, int>\n"
|
" return %1 : !torch.tuple<int, int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.aminmax\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.tuple<int, int> {\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %1 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||||
|
" return %1 : !torch.tuple<int, int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>) -> !torch.int {\n"
|
||||||
" %false = torch.constant.bool false\n"
|
" %false = torch.constant.bool false\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
|
|
|
@ -113,6 +113,25 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
|
||||||
.getValues();
|
.getValues();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reduction function to calculate min along given `dim`.
|
||||||
|
static Value createMinAlongDimension(PatternRewriter &rewriter, Location loc,
|
||||||
|
Operation *op, Value input, Value dim,
|
||||||
|
bool keepDim) {
|
||||||
|
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
|
||||||
|
BaseTensorType valueType = cast<BaseTensorType>(computeReductionType(
|
||||||
|
rewriter, op, cast<BaseTensorType>(input.getType()), dim, keepDim));
|
||||||
|
if (!valueType)
|
||||||
|
return nullptr;
|
||||||
|
BaseTensorType indexType =
|
||||||
|
cast<BaseTensorType>(valueType.getWithSizesAndDtype(
|
||||||
|
!valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
|
||||||
|
: llvm::ArrayRef(valueType.getSizes()),
|
||||||
|
IntegerType::get(op->getContext(), 64, IntegerType::Signed)));
|
||||||
|
return rewriter
|
||||||
|
.create<AtenMinDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
|
||||||
|
.getValues();
|
||||||
|
}
|
||||||
|
|
||||||
// Helper for creating `aten::sub_tensor_op`.
|
// Helper for creating `aten::sub_tensor_op`.
|
||||||
static Value createTensorSub(PatternRewriter &rewriter, Location loc,
|
static Value createTensorSub(PatternRewriter &rewriter, Location loc,
|
||||||
Type tensorType, Value lhs, Value rhs) {
|
Type tensorType, Value lhs, Value rhs) {
|
||||||
|
@ -605,65 +624,6 @@ static Value performLastReduceAndPermute(PatternRewriter &rewriter,
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
|
||||||
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
|
|
||||||
/// number of dimensions across which the max needs to be computed.
|
|
||||||
/// Eg:
|
|
||||||
/// INPUT:
|
|
||||||
/// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False)
|
|
||||||
///
|
|
||||||
/// OUTPUT:
|
|
||||||
/// input_1 = aten.max.dim(initial_input, 2, keepdim) #1
|
|
||||||
/// input_2 = aten.max.dim(input_1, 1, keepdim) #2
|
|
||||||
/// final_output = aten.max.dim(input_2, 0, keepdim) #3
|
|
||||||
///
|
|
||||||
/// NOTE: We iterate over, in reverse order, every dimension included in `dim`
|
|
||||||
/// of the `aten.amax` op and create an `aten.amax.dim` op.
|
|
||||||
/// Input tensor to the next `aten.amax.dim` op is thus the output of the
|
|
||||||
/// previous `aten.amax.dim` op.
|
|
||||||
class DecomposeAtenAmaxOp : public OpRewritePattern<AtenAmaxOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
LogicalResult matchAndRewrite(AtenAmaxOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
Location loc = op.getLoc();
|
|
||||||
SmallVector<int64_t, 4> dims;
|
|
||||||
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims)))
|
|
||||||
|
|
||||||
return rewriter.notifyMatchFailure(op,
|
|
||||||
"non-const dim parameter unsupported");
|
|
||||||
|
|
||||||
bool keepDim;
|
|
||||||
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "Expected a constant boolean value for keepDim");
|
|
||||||
|
|
||||||
Value input = op.getSelf();
|
|
||||||
auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
|
|
||||||
if (!inputTy || !inputTy.hasSizes()) {
|
|
||||||
return rewriter.notifyMatchFailure(op,
|
|
||||||
"Expected input type having sizes");
|
|
||||||
}
|
|
||||||
// For every dimension included in `dim` of the op, iterated over in
|
|
||||||
// reverse order, we create a call to aten.max.dim.
|
|
||||||
std::sort(dims.rbegin(), dims.rend());
|
|
||||||
for (int64_t dimInt : dims) {
|
|
||||||
int64_t inputRank = inputTy.getSizes().size();
|
|
||||||
dimInt = toPositiveDim(dimInt, inputRank);
|
|
||||||
if (!isValidDim(dimInt, inputRank))
|
|
||||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
|
||||||
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
|
||||||
loc, rewriter.getI64IntegerAttr(dimInt));
|
|
||||||
// The input to the next invocation of aten.max.dim is the output of the
|
|
||||||
// previous aten.max.dim op.
|
|
||||||
input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim);
|
|
||||||
}
|
|
||||||
rewriter.replaceOp(op, input);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // end namespace
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
|
class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -1880,52 +1840,69 @@ public:
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenAMinMaxOp : public OpRewritePattern<Torch::AtenAminOp> {
|
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
|
||||||
|
/// number of dimensions across which the max needs to be computed.
|
||||||
|
/// Eg:
|
||||||
|
/// INPUT:
|
||||||
|
/// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False)
|
||||||
|
///
|
||||||
|
/// OUTPUT:
|
||||||
|
/// input_1 = aten.max.dim(initial_input, 2, keepdim) #1
|
||||||
|
/// input_2 = aten.max.dim(input_1, 1, keepdim) #2
|
||||||
|
/// final_output = aten.max.dim(input_2, 0, keepdim) #3
|
||||||
|
///
|
||||||
|
/// NOTE: We iterate over, in reverse order, every dimension included in `dim`
|
||||||
|
/// of the `aten.amax` op and create an `aten.amax.dim` op.
|
||||||
|
/// Input tensor to the next `aten.amax.dim` op is thus the output of the
|
||||||
|
/// previous `aten.amax.dim` op.
|
||||||
|
template <typename OpTy, typename DecompOpTy>
|
||||||
|
class DecomposeAtenAminAmaxOp : public OpRewritePattern<OpTy> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern<Torch::AtenAminOp>::OpRewritePattern;
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(Torch::AtenAminOp op,
|
LogicalResult matchAndRewrite(OpTy op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
llvm::SmallVector<int64_t> dimList;
|
Location loc = op.getLoc();
|
||||||
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) {
|
bool keepDim;
|
||||||
return rewriter.notifyMatchFailure(op, "dims not foldable constants");
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Expected a constant boolean value for keepDim");
|
||||||
|
|
||||||
|
Value input = op.getSelf();
|
||||||
|
auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
|
||||||
|
if (!inputTy || !inputTy.hasSizes()) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"Expected input type having sizes");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool keepdim;
|
SmallVector<int64_t, 4> dims;
|
||||||
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepdim))) {
|
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims)))
|
||||||
return rewriter.notifyMatchFailure(op, "keepdims not foldable constants");
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"non-const dim parameter unsupported");
|
||||||
|
if (dims.size() == 0) {
|
||||||
|
dims = llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getSizes().size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto loc = op.getLoc();
|
// For every dimension included in `dim` of the op, iterated over in
|
||||||
std::sort(dimList.begin(), dimList.end(), std::greater<int64_t>());
|
// reverse order, we create a call to aten.max.dim.
|
||||||
|
std::sort(dims.rbegin(), dims.rend());
|
||||||
Value reduction = op.getSelf();
|
for (int64_t dimInt : dims) {
|
||||||
auto resultTy = cast<Torch::ValueTensorType>(op.getType());
|
int64_t inputRank = inputTy.getSizes().size();
|
||||||
auto reductionTy = cast<Torch::ValueTensorType>(reduction.getType());
|
dimInt = toPositiveDim(dimInt, inputRank);
|
||||||
llvm::SmallVector<int64_t> reductionShape(reductionTy.getSizes());
|
if (!isValidDim(dimInt, inputRank))
|
||||||
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||||
for (auto dim : dimList) {
|
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
||||||
auto dimValue = rewriter.create<Torch::ConstantIntOp>(
|
loc, rewriter.getI64IntegerAttr(dimInt));
|
||||||
loc, rewriter.getI64IntegerAttr(dim));
|
// The input to the next invocation of aten.max.dim is the output of the
|
||||||
reductionShape[dim] = 1;
|
// previous aten.max.dim op.
|
||||||
if (!keepdim) {
|
static_assert(std::is_same_v<OpTy, AtenAmaxOp> ||
|
||||||
for (int i = dim, s = reductionShape.size() - 1; i < s; ++i)
|
std::is_same_v<OpTy, AtenAminOp>);
|
||||||
reductionShape[i] = reductionShape[i + 1];
|
if (std::is_same_v<OpTy, AtenAmaxOp>) {
|
||||||
reductionShape.resize(reductionShape.size() - 1);
|
input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim);
|
||||||
|
} else if (std::is_same_v<OpTy, AtenAminOp>) {
|
||||||
|
input = createMinAlongDimension(rewriter, loc, op, input, dim, keepDim);
|
||||||
}
|
}
|
||||||
|
|
||||||
reductionTy = rewriter.getType<Torch::ValueTensorType>(
|
|
||||||
reductionShape, resultTy.getOptionalDtype());
|
|
||||||
auto idxTy = rewriter.getType<Torch::ValueTensorType>(
|
|
||||||
reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true));
|
|
||||||
llvm::SmallVector<Type, 2> types{reductionTy, idxTy};
|
|
||||||
|
|
||||||
reduction = rewriter
|
|
||||||
.create<Torch::AtenMinDimOp>(loc, types, reduction,
|
|
||||||
dimValue, op.getKeepdim())
|
|
||||||
.getResult(0);
|
|
||||||
}
|
}
|
||||||
|
rewriter.replaceOp(op, input);
|
||||||
rewriter.replaceOp(op, reduction);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1987,6 +1964,36 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Decompose `AtenAminmaxOp` to `AtenAminOp` + `AtenAmaxOp`
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenAminmaxOp : public OpRewritePattern<AtenAminmaxOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<AtenAminmaxOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenAminmaxOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
|
Torch::ListType listType =
|
||||||
|
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>());
|
||||||
|
Value dimList;
|
||||||
|
if (isa<Torch::NoneType>(op.getDim().getType())) {
|
||||||
|
dimList = rewriter.create<Torch::PrimListConstructOp>(loc, listType,
|
||||||
|
ArrayRef<Value>{});
|
||||||
|
} else {
|
||||||
|
dimList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
loc, listType, ArrayRef<Value>{op.getDim()});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto amin = rewriter.create<AtenAminOp>(
|
||||||
|
loc, op.getMin().getType(), op.getSelf(), dimList, op.getKeepdim());
|
||||||
|
auto amax = rewriter.create<AtenAmaxOp>(
|
||||||
|
loc, op.getMax().getType(), op.getSelf(), dimList, op.getKeepdim());
|
||||||
|
rewriter.replaceOp(op, {amin, amax});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Decompose `aten.bucketize` into the following op sequence:
|
// Decompose `aten.bucketize` into the following op sequence:
|
||||||
//
|
//
|
||||||
// def aten_bucketize(input, boundaries, out_int32, right):
|
// def aten_bucketize(input, boundaries, out_int32, right):
|
||||||
|
@ -8598,7 +8605,6 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAMinMaxOp>(patterns);
|
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
|
||||||
|
@ -8631,10 +8637,15 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposePrimsIotaOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposePrimsIotaOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<
|
||||||
|
DecomposeAtenAminAmaxOp<AtenAmaxOp, AtenMaxDimOp>>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<
|
||||||
|
DecomposeAtenAminAmaxOp<AtenAminOp, AtenMinDimOp>>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<
|
addPatternIfTargetOpIsIllegal<
|
||||||
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
|
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<
|
addPatternIfTargetOpIsIllegal<
|
||||||
DecomposeAtenArgMinMaxOp<AtenArgminOp, AtenMinDimOp>>(patterns);
|
DecomposeAtenArgMinMaxOp<AtenArgminOp, AtenMinDimOp>>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAminmaxOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSquareOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSquareOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStdOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenStdOp>(patterns);
|
||||||
|
@ -8707,7 +8718,6 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectScatterOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectScatterOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAmaxOp>(patterns);
|
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
|
||||||
|
|
|
@ -438,6 +438,9 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<AtenLinspaceOp>();
|
target.addIllegalOp<AtenLinspaceOp>();
|
||||||
target.addIllegalOp<AtenArgmaxOp>();
|
target.addIllegalOp<AtenArgmaxOp>();
|
||||||
target.addIllegalOp<AtenArgminOp>();
|
target.addIllegalOp<AtenArgminOp>();
|
||||||
|
target.addIllegalOp<AtenAminmaxOp>();
|
||||||
|
target.addIllegalOp<AtenAmaxOp>();
|
||||||
|
target.addIllegalOp<AtenAminOp>();
|
||||||
target.addIllegalOp<AtenSquareOp>();
|
target.addIllegalOp<AtenSquareOp>();
|
||||||
target.addIllegalOp<AtenVarOp>();
|
target.addIllegalOp<AtenVarOp>();
|
||||||
target.addIllegalOp<AtenStdOp>();
|
target.addIllegalOp<AtenStdOp>();
|
||||||
|
@ -502,7 +505,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<AtenNumpyTOp>();
|
target.addIllegalOp<AtenNumpyTOp>();
|
||||||
target.addIllegalOp<AtenSelectScatterOp>();
|
target.addIllegalOp<AtenSelectScatterOp>();
|
||||||
target.addIllegalOp<AtenVarDimOp>();
|
target.addIllegalOp<AtenVarDimOp>();
|
||||||
target.addIllegalOp<AtenAmaxOp>();
|
|
||||||
target.addIllegalOp<AtenVarCorrectionOp>();
|
target.addIllegalOp<AtenVarCorrectionOp>();
|
||||||
target.addIllegalOp<AtenStdDimOp>();
|
target.addIllegalOp<AtenStdDimOp>();
|
||||||
target.addIllegalOp<AtenStdCorrectionOp>();
|
target.addIllegalOp<AtenStdCorrectionOp>();
|
||||||
|
|
|
@ -830,6 +830,9 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
|
||||||
}
|
}
|
||||||
|
|
||||||
STABLEHLO_PASS_SET = {
|
STABLEHLO_PASS_SET = {
|
||||||
|
"ReduceAminmaxSingleDim_basic",
|
||||||
|
"ReduceAminmaxAllDims_basic",
|
||||||
|
"ReduceAmaxEmptyDim_basic",
|
||||||
"ReduceMinAlongDimNegative_basic",
|
"ReduceMinAlongDimNegative_basic",
|
||||||
"ReduceMinAlongDim_basic",
|
"ReduceMinAlongDim_basic",
|
||||||
"ArgminModule_with_dim",
|
"ArgminModule_with_dim",
|
||||||
|
|
|
@ -722,6 +722,13 @@ def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = Fa
|
||||||
def aten〇amin〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
|
def aten〇amin〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
|
||||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
||||||
|
|
||||||
|
def aten〇aminmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[List[int], List[int]]:
|
||||||
|
if dim is None:
|
||||||
|
return [], []
|
||||||
|
else:
|
||||||
|
reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
|
||||||
|
return reduced_shape, reduced_shape
|
||||||
|
|
||||||
def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
|
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
|
||||||
|
|
||||||
|
@ -4524,6 +4531,11 @@ def aten〇amin〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), k
|
||||||
def aten〇min〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]:
|
def aten〇min〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]:
|
||||||
return aten〇min〡dtype(self_rank_dtype), torch.int64
|
return aten〇min〡dtype(self_rank_dtype), torch.int64
|
||||||
|
|
||||||
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||||
|
def aten〇aminmax〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[int, int]:
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
return self_dtype, self_dtype
|
||||||
|
|
||||||
@check_dtype_function(
|
@check_dtype_function(
|
||||||
_check_tensors_with_the_same_dtype(
|
_check_tensors_with_the_same_dtype(
|
||||||
num_of_tensors=1,
|
num_of_tensors=1,
|
||||||
|
|
|
@ -841,6 +841,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::min.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
|
emit("aten::min.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
|
||||||
emit("aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
|
emit("aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
|
||||||
emit("aten::amin : (Tensor, int[], bool) -> (Tensor)")
|
emit("aten::amin : (Tensor, int[], bool) -> (Tensor)")
|
||||||
|
emit("aten::aminmax : (Tensor, int?, bool) -> (Tensor, Tensor)")
|
||||||
emit(
|
emit(
|
||||||
"aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True
|
"aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True
|
||||||
)
|
)
|
||||||
|
|
|
@ -1204,6 +1204,29 @@ def ReduceAmaxMultiDim_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ReduceAmaxEmptyDim(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.amax(a, dim=())
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReduceAmaxEmptyDim())
|
||||||
|
def ReduceAmaxEmptyDim_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5, high=100))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ReduceAmaxOutOfOrderDim(torch.nn.Module):
|
class ReduceAmaxOutOfOrderDim(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -1273,6 +1296,52 @@ def ReduceAminSingleDim_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ReduceAminmaxSingleDim(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.aminmax(a, dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReduceAminmaxSingleDim())
|
||||||
|
def ReduceAminmaxSingleDim_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5, high=100))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ReduceAminmaxAllDims(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.aminmax(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReduceAminmaxAllDims())
|
||||||
|
def ReduceAminmaxAllDims_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5, high=100))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ReduceMinFloatModule(torch.nn.Module):
|
class ReduceMinFloatModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue