mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] add folder for prim.min.int (#1864)
parent
320e67ff34
commit
2f6fdb7f0b
|
@ -10894,6 +10894,7 @@ def Torch_PrimMinIntOp : Torch_Op<"prim.min.int", [
|
||||||
printDefaultTorchOp(printer, *this, 2, 1);
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [
|
def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [
|
||||||
|
|
|
@ -2348,6 +2348,25 @@ OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) {
|
||||||
*std::min_element(values.begin(), values.end()));
|
*std::min_element(values.begin(), values.end()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PrimMinIntOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
|
||||||
|
// If both operands are the same, then the operation is an identity.
|
||||||
|
if (getA() == getB())
|
||||||
|
return getA();
|
||||||
|
|
||||||
|
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
|
||||||
|
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
|
||||||
|
if (!lhs || !rhs)
|
||||||
|
return nullptr;
|
||||||
|
// Torch semantics are that !torch.int is 64-bit signed.
|
||||||
|
return IntegerAttr::get(
|
||||||
|
lhs.getType(),
|
||||||
|
std::min(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue()));
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ShapeCalculateOp
|
// ShapeCalculateOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -662,7 +662,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
|
emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
|
||||||
emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)")
|
emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)")
|
||||||
emit("prim::min.self_int : (int[]) -> (int)", has_folder=True)
|
emit("prim::min.self_int : (int[]) -> (int)", has_folder=True)
|
||||||
emit("prim::min.int : (int, int) -> (int)")
|
emit("prim::min.int : (int, int) -> (int)", has_folder=True)
|
||||||
emit("prim::max.self_int : (int[]) -> (int)")
|
emit("prim::max.self_int : (int[]) -> (int)")
|
||||||
emit("prim::max.int : (int, int) -> (int)", has_folder=True)
|
emit("prim::max.int : (int, int) -> (int)", has_folder=True)
|
||||||
emit("prim::RaiseException : (str, str?) -> ()")
|
emit("prim::RaiseException : (str, str?) -> ()")
|
||||||
|
|
|
@ -1270,6 +1270,25 @@ def LogSoftmaxIntModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class PrimMinIntModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.ops.prim.min(1, -1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: PrimMinIntModule())
|
||||||
|
def PrimMinIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class NumToTensorIntModule(torch.nn.Module):
|
class NumToTensorIntModule(torch.nn.Module):
|
||||||
|
|
||||||
|
|
|
@ -502,6 +502,24 @@ func.func @torch.prim.max.int$constant() -> !torch.int {
|
||||||
return %0 : !torch.int
|
return %0 : !torch.int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.prim.min.int$identity(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.int {
|
||||||
|
// CHECK: return %[[ARG]] : !torch.int
|
||||||
|
func.func @torch.prim.min.int$identity(%arg0: !torch.int) -> !torch.int {
|
||||||
|
%0 = torch.prim.min.int %arg0, %arg0 : !torch.int, !torch.int -> !torch.int
|
||||||
|
return %0 : !torch.int
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.prim.min.int$constant() -> !torch.int {
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int -1
|
||||||
|
// CHECK: return %[[INT1]] : !torch.int
|
||||||
|
func.func @torch.prim.min.int$constant() -> !torch.int {
|
||||||
|
%int-1 = torch.constant.int -1
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%0 = torch.prim.min.int %int-1, %int3 : !torch.int, !torch.int -> !torch.int
|
||||||
|
return %0 : !torch.int
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.prim.min.self_int$basic() -> !torch.int {
|
// CHECK-LABEL: func.func @torch.prim.min.self_int$basic() -> !torch.int {
|
||||||
// CHECK: %[[M1:.*]] = torch.constant.int -1
|
// CHECK: %[[M1:.*]] = torch.constant.int -1
|
||||||
// CHECK: return %[[M1]] : !torch.int
|
// CHECK: return %[[M1]] : !torch.int
|
||||||
|
|
Loading…
Reference in New Issue