[Torch Dialect] add folder for prim.min.int (#1864)

Yuanqiang Liu 2023-02-11 05:58:15 +08:00 committed by GitHub
parent 320e67ff34
commit 2f6fdb7f0b
5 changed files with 58 additions and 1 deletion


@@ -10894,6 +10894,7 @@ def Torch_PrimMinIntOp : Torch_Op<"prim.min.int", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasFolder = 1;
 }

 def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [


@@ -2348,6 +2348,25 @@ OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) {
       *std::min_element(values.begin(), values.end()));
 }

+//===----------------------------------------------------------------------===//
+// PrimMinIntOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
+  // If both operands are the same, then the operation is an identity.
+  if (getA() == getB())
+    return getA();
+
+  auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
+  auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
+  if (!lhs || !rhs)
+    return nullptr;
+  // Torch semantics are that !torch.int is 64-bit signed.
+  return IntegerAttr::get(
+      lhs.getType(),
+      std::min(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue()));
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCalculateOp
 //===----------------------------------------------------------------------===//
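
For intuition, the constant-folding half of the rule above behaves like the minimal Python sketch below. This is a sketch only, not torch-mlir API: fold_min_int is a hypothetical helper, None stands for a non-constant operand, and plain Python ints stand in for the 64-bit signed payload of !torch.int. The getA() == getB() identity case has no analogue here, since it compares SSA values rather than constants.

    from typing import Optional

    def fold_min_int(a: Optional[int], b: Optional[int]) -> Optional[int]:
        # Return the folded constant, or None when the op must stay in the IR.
        if a is None or b is None:
            return None  # a non-constant operand blocks the constant fold
        return min(a, b)  # the real folder compares via getSExtValue()

    assert fold_min_int(1, -1) == -1      # same case as the tests below
    assert fold_min_int(None, 3) is None  # unknown operand: no fold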


@@ -662,7 +662,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
     emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)")
     emit("prim::min.self_int : (int[]) -> (int)", has_folder=True)
-    emit("prim::min.int : (int, int) -> (int)")
+    emit("prim::min.int : (int, int) -> (int)", has_folder=True)
     emit("prim::max.self_int : (int[]) -> (int)")
     emit("prim::max.int : (int, int) -> (int)", has_folder=True)
     emit("prim::RaiseException : (str, str?) -> ()")


@@ -1270,6 +1270,25 @@ def LogSoftmaxIntModule_basic(module, tu: TestUtils):

 # ==============================================================================

+class PrimMinIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return torch.ops.prim.min(1, -1)
+
+
+@register_test_case(module_factory=lambda: PrimMinIntModule())
+def PrimMinIntModule_basic(module, tu: TestUtils):
+    module.forward()
+
+# ==============================================================================
+
 class NumToTensorIntModule(torch.nn.Module):
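
The e2e harness obtains its golden value by running forward eagerly, so the same semantics can be sanity-checked outside the suite. Assuming a stock PyTorch install where torch.ops.prim.min resolves its (int, int) overload, as the test's forward call does:

    import torch

    # The eager answer that the folded IR must reproduce: min(1, -1) == -1.
    assert torch.ops.prim.min(1, -1) == -1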


@@ -502,6 +502,24 @@ func.func @torch.prim.max.int$constant() -> !torch.int {
   return %0 : !torch.int
 }

+// CHECK-LABEL:   func.func @torch.prim.min.int$identity(
+// CHECK-SAME:        %[[ARG:.*]]: !torch.int) -> !torch.int {
+// CHECK:           return %[[ARG]] : !torch.int
+func.func @torch.prim.min.int$identity(%arg0: !torch.int) -> !torch.int {
+  %0 = torch.prim.min.int %arg0, %arg0 : !torch.int, !torch.int -> !torch.int
+  return %0 : !torch.int
+}
+
+// CHECK-LABEL:   func.func @torch.prim.min.int$constant() -> !torch.int {
+// CHECK:           %[[INT1:.*]] = torch.constant.int -1
+// CHECK:           return %[[INT1]] : !torch.int
+func.func @torch.prim.min.int$constant() -> !torch.int {
+  %int-1 = torch.constant.int -1
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.min.int %int-1, %int3 : !torch.int, !torch.int -> !torch.int
+  return %0 : !torch.int
+}
+
 // CHECK-LABEL:   func.func @torch.prim.min.self_int$basic() -> !torch.int {
 // CHECK:           %[[M1:.*]] = torch.constant.int -1
 // CHECK:           return %[[M1]] : !torch.int
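
Together, the two new torch.prim.min.int cases exercise both paths of the folder under -canonicalize: the identity test folds min(x, x) to x with no constants involved, and the constant test checks the signed comparison by folding min(-1, 3) to -1.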