[MLIR][TORCH] Add aten.sort.int op

Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>
pull/1619/head
Vivek Khandelwal 2022-11-18 17:17:07 +05:30
parent 29c8f47723
commit 4cbd3927d7
5 changed files with 138 additions and 0 deletions

View File

@ -8372,6 +8372,28 @@ def Torch_AtenAnyBoolOp : Torch_Op<"aten.any.bool", [
}];
}
def Torch_AtenSortIntOp : Torch_Op<"aten.sort.int", [
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::sort.int : (int[], bool) -> ()`";
let arguments = (ins
AnyTorchListOfTorchIntType:$self,
Torch_BoolType:$reverse
);
let results = (outs
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSortIntOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 0);
}
void AtenSortIntOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 0);
}
}];
let hasCanonicalizer = 1;
}
def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -1345,6 +1345,40 @@ OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenSortIntOp
//===----------------------------------------------------------------------===//
void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenSortIntOp op, PatternRewriter &rewriter) {
SmallVector<int64_t> listElements;
if (!matchPattern(op.self(), m_TorchListOfConstantInts(listElements)))
return rewriter.notifyMatchFailure(
op, "all input list elements must be constant ints");
bool reverse;
if (!matchPattern(op.reverse(), m_TorchConstantBool(&reverse)))
return rewriter.notifyMatchFailure(
op, "Expected reverse arg to be constant bool.");
std::sort(listElements.begin(), listElements.end());
if (reverse)
std::reverse(listElements.begin(), listElements.end());
SmallVector<Value> sortedListElements;
for (int64_t elem : listElements)
sortedListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
op->getLoc(), rewriter.getI64IntegerAttr(elem)));
Value result = rewriter.create<Torch::PrimListConstructOp>(
op->getLoc(), Torch::ListType::get(rewriter.getType<Torch::IntType>()),
sortedListElements);
op.self().replaceAllUsesWith(result);
rewriter.eraseOp(op);
return success();
});
}
//===----------------------------------------------------------------------===//
// NonValueTensorLiteralOp
//===----------------------------------------------------------------------===//

View File

@ -550,6 +550,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::insert.t : (t[], int, t) -> ()")
emit("aten::ne.int_list : (int[], int[]) -> (bool)")
emit("aten::any.bool : (bool[]) -> (bool)")
emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True)
# Str ops.
emit("aten::add.str : (str, str) -> (str)")

View File

@ -3053,3 +3053,48 @@ class UpSampleNearest2dBackwardScalesNone(torch.nn.Module):
@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardScalesNone())
def UpSampleNearest2dBackwardScalesNone_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 4, 8))
# ==============================================================================
class SortIntList(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
a = [1, 0, 3, 2]
b = [0, 1, 2, 3]
a.sort()
return a == b
@register_test_case(module_factory=lambda: SortIntList())
def SortIntList_basic(module, tu: TestUtils):
module.forward()
class SortIntListReverse(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
a = [1, 0, 3, 2]
b = [3, 2, 1, 0]
a.sort(reverse=True)
return a == b
@register_test_case(module_factory=lambda: SortIntListReverse())
def SortIntListReverse_basic(module, tu: TestUtils):
module.forward()

View File

@ -1734,3 +1734,39 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.
%2 = torch.aten.div.Tensor_mode %1, %0, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
return %2 : !torch.vtensor<[],si64>
}
// CHECK-LABEL: func.func @torch.aten.sort.int$reverse_false() -> !torch.list<int> {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT1]], %[[INT2]], %[[INT3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: return %[[RESULT]] : !torch.list<int>
func.func @torch.aten.sort.int$reverse_false() -> !torch.list<int> {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%0 = torch.prim.ListConstruct %int1, %int0, %int3, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.aten.sort.int %0, %false : !torch.list<int>, !torch.bool
return %0 : !torch.list<int>
}
// CHECK-LABEL: func.func @torch.aten.sort.int$reverse_true() -> !torch.list<int> {
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT2]], %[[INT1]], %[[INT0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: return %[[RESULT]] : !torch.list<int>
func.func @torch.aten.sort.int$reverse_true() -> !torch.list<int> {
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%0 = torch.prim.ListConstruct %int1, %int0, %int3, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.aten.sort.int %0, %true : !torch.list<int>, !torch.bool
return %0 : !torch.list<int>
}