mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add aten.sort.int op
Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>pull/1619/head
parent
29c8f47723
commit
4cbd3927d7
|
@ -8372,6 +8372,28 @@ def Torch_AtenAnyBoolOp : Torch_Op<"aten.any.bool", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenSortIntOp : Torch_Op<"aten.sort.int", [
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::sort.int : (int[], bool) -> ()`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchListOfTorchIntType:$self,
|
||||||
|
Torch_BoolType:$reverse
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenSortIntOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 0);
|
||||||
|
}
|
||||||
|
void AtenSortIntOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 0);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [
|
def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -1345,6 +1345,40 @@ OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenSortIntOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
|
MLIRContext *context) {
|
||||||
|
patterns.add(+[](AtenSortIntOp op, PatternRewriter &rewriter) {
|
||||||
|
SmallVector<int64_t> listElements;
|
||||||
|
if (!matchPattern(op.self(), m_TorchListOfConstantInts(listElements)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "all input list elements must be constant ints");
|
||||||
|
bool reverse;
|
||||||
|
if (!matchPattern(op.reverse(), m_TorchConstantBool(&reverse)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Expected reverse arg to be constant bool.");
|
||||||
|
|
||||||
|
std::sort(listElements.begin(), listElements.end());
|
||||||
|
if (reverse)
|
||||||
|
std::reverse(listElements.begin(), listElements.end());
|
||||||
|
|
||||||
|
SmallVector<Value> sortedListElements;
|
||||||
|
for (int64_t elem : listElements)
|
||||||
|
sortedListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
op->getLoc(), rewriter.getI64IntegerAttr(elem)));
|
||||||
|
Value result = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
op->getLoc(), Torch::ListType::get(rewriter.getType<Torch::IntType>()),
|
||||||
|
sortedListElements);
|
||||||
|
|
||||||
|
op.self().replaceAllUsesWith(result);
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// NonValueTensorLiteralOp
|
// NonValueTensorLiteralOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -550,6 +550,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::insert.t : (t[], int, t) -> ()")
|
emit("aten::insert.t : (t[], int, t) -> ()")
|
||||||
emit("aten::ne.int_list : (int[], int[]) -> (bool)")
|
emit("aten::ne.int_list : (int[], int[]) -> (bool)")
|
||||||
emit("aten::any.bool : (bool[]) -> (bool)")
|
emit("aten::any.bool : (bool[]) -> (bool)")
|
||||||
|
emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True)
|
||||||
|
|
||||||
# Str ops.
|
# Str ops.
|
||||||
emit("aten::add.str : (str, str) -> (str)")
|
emit("aten::add.str : (str, str) -> (str)")
|
||||||
|
|
|
@ -3053,3 +3053,48 @@ class UpSampleNearest2dBackwardScalesNone(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardScalesNone())
|
@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardScalesNone())
|
||||||
def UpSampleNearest2dBackwardScalesNone_basic(module, tu: TestUtils):
|
def UpSampleNearest2dBackwardScalesNone_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 1, 4, 8))
|
module.forward(tu.rand(1, 1, 4, 8))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class SortIntList(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
a = [1, 0, 3, 2]
|
||||||
|
b = [0, 1, 2, 3]
|
||||||
|
a.sort()
|
||||||
|
return a == b
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: SortIntList())
|
||||||
|
def SortIntList_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class SortIntListReverse(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
a = [1, 0, 3, 2]
|
||||||
|
b = [3, 2, 1, 0]
|
||||||
|
a.sort(reverse=True)
|
||||||
|
return a == b
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: SortIntListReverse())
|
||||||
|
def SortIntListReverse_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
|
@ -1734,3 +1734,39 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.
|
||||||
%2 = torch.aten.div.Tensor_mode %1, %0, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
|
%2 = torch.aten.div.Tensor_mode %1, %0, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
|
||||||
return %2 : !torch.vtensor<[],si64>
|
return %2 : !torch.vtensor<[],si64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.sort.int$reverse_false() -> !torch.list<int> {
|
||||||
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT1]], %[[INT2]], %[[INT3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.list<int>
|
||||||
|
func.func @torch.aten.sort.int$reverse_false() -> !torch.list<int> {
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%0 = torch.prim.ListConstruct %int1, %int0, %int3, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
torch.aten.sort.int %0, %false : !torch.list<int>, !torch.bool
|
||||||
|
return %0 : !torch.list<int>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.sort.int$reverse_true() -> !torch.list<int> {
|
||||||
|
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT2]], %[[INT1]], %[[INT0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.list<int>
|
||||||
|
func.func @torch.aten.sort.int$reverse_true() -> !torch.list<int> {
|
||||||
|
%true = torch.constant.bool true
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%0 = torch.prim.ListConstruct %int1, %int0, %int3, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
torch.aten.sort.int %0, %true : !torch.list<int>, !torch.bool
|
||||||
|
return %0 : !torch.list<int>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue