mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add aten.sort.int op
Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>pull/1619/head
parent
29c8f47723
commit
4cbd3927d7
|
@ -8372,6 +8372,28 @@ def Torch_AtenAnyBoolOp : Torch_Op<"aten.any.bool", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSortIntOp : Torch_Op<"aten.sort.int", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::sort.int : (int[], bool) -> ()`";
|
||||
let arguments = (ins
|
||||
AnyTorchListOfTorchIntType:$self,
|
||||
Torch_BoolType:$reverse
|
||||
);
|
||||
let results = (outs
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenSortIntOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 0);
|
||||
}
|
||||
void AtenSortIntOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 0);
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -1345,6 +1345,40 @@ OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSortIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add(+[](AtenSortIntOp op, PatternRewriter &rewriter) {
|
||||
SmallVector<int64_t> listElements;
|
||||
if (!matchPattern(op.self(), m_TorchListOfConstantInts(listElements)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "all input list elements must be constant ints");
|
||||
bool reverse;
|
||||
if (!matchPattern(op.reverse(), m_TorchConstantBool(&reverse)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected reverse arg to be constant bool.");
|
||||
|
||||
std::sort(listElements.begin(), listElements.end());
|
||||
if (reverse)
|
||||
std::reverse(listElements.begin(), listElements.end());
|
||||
|
||||
SmallVector<Value> sortedListElements;
|
||||
for (int64_t elem : listElements)
|
||||
sortedListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
op->getLoc(), rewriter.getI64IntegerAttr(elem)));
|
||||
Value result = rewriter.create<Torch::PrimListConstructOp>(
|
||||
op->getLoc(), Torch::ListType::get(rewriter.getType<Torch::IntType>()),
|
||||
sortedListElements);
|
||||
|
||||
op.self().replaceAllUsesWith(result);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NonValueTensorLiteralOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -550,6 +550,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::insert.t : (t[], int, t) -> ()")
|
||||
emit("aten::ne.int_list : (int[], int[]) -> (bool)")
|
||||
emit("aten::any.bool : (bool[]) -> (bool)")
|
||||
emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True)
|
||||
|
||||
# Str ops.
|
||||
emit("aten::add.str : (str, str) -> (str)")
|
||||
|
|
|
@ -3053,3 +3053,48 @@ class UpSampleNearest2dBackwardScalesNone(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardScalesNone())
|
||||
def UpSampleNearest2dBackwardScalesNone_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 4, 8))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SortIntList(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
a = [1, 0, 3, 2]
|
||||
b = [0, 1, 2, 3]
|
||||
a.sort()
|
||||
return a == b
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SortIntList())
|
||||
def SortIntList_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class SortIntListReverse(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
a = [1, 0, 3, 2]
|
||||
b = [3, 2, 1, 0]
|
||||
a.sort(reverse=True)
|
||||
return a == b
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SortIntListReverse())
|
||||
def SortIntListReverse_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
|
|
@ -1734,3 +1734,39 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.
|
|||
%2 = torch.aten.div.Tensor_mode %1, %0, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
|
||||
return %2 : !torch.vtensor<[],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.sort.int$reverse_false() -> !torch.list<int> {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT1]], %[[INT2]], %[[INT3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: return %[[RESULT]] : !torch.list<int>
|
||||
func.func @torch.aten.sort.int$reverse_false() -> !torch.list<int> {
|
||||
%false = torch.constant.bool false
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%int3 = torch.constant.int 3
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.prim.ListConstruct %int1, %int0, %int3, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
torch.aten.sort.int %0, %false : !torch.list<int>, !torch.bool
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.sort.int$reverse_true() -> !torch.list<int> {
|
||||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT2]], %[[INT1]], %[[INT0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: return %[[RESULT]] : !torch.list<int>
|
||||
func.func @torch.aten.sort.int$reverse_true() -> !torch.list<int> {
|
||||
%true = torch.constant.bool true
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%int3 = torch.constant.int 3
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.prim.ListConstruct %int1, %int0, %int3, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
torch.aten.sort.int %0, %true : !torch.list<int>, !torch.bool
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue