[Torch Dialect] add more scalar op folders (#2265)

pull/2273/head snapshot-20230629.884
Yuanqiang Liu 2023-06-29 10:37:13 +08:00 committed by GitHub
parent 82819350e1
commit 449cfb8375
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 71 additions and 3 deletions

View File

@ -10554,6 +10554,7 @@ def Torch_AtenAddFloatIntOp : Torch_Op<"aten.add.float_int", [
printDefaultTorchOp(printer, *this, 2, 1); printDefaultTorchOp(printer, *this, 2, 1);
} }
}]; }];
let hasFolder = 1;
} }
def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [ def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [
@ -10603,6 +10604,7 @@ def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [
printDefaultTorchOp(printer, *this, 2, 1); printDefaultTorchOp(printer, *this, 2, 1);
} }
}]; }];
let hasFolder = 1;
} }
def Torch_AtenDivFloatOp : Torch_Op<"aten.div.float", [ def Torch_AtenDivFloatOp : Torch_Op<"aten.div.float", [
@ -10651,6 +10653,7 @@ def Torch_AtenNegFloatOp : Torch_Op<"aten.neg.float", [
printDefaultTorchOp(printer, *this, 1, 1); printDefaultTorchOp(printer, *this, 1, 1);
} }
}]; }];
let hasFolder = 1;
} }
def Torch_AtenEqFloatOp : Torch_Op<"aten.eq.float", [ def Torch_AtenEqFloatOp : Torch_Op<"aten.eq.float", [

View File

@ -2319,6 +2319,15 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
} }
//===----------------------------------------------------------------------===//
// AtenMulFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenMulFloatOp::fold(FoldAdaptor adaptor) {
return atenBinaryFloatOperatorFoldHelper(
adaptor.getOperands(), [](double a, double b) { return a * b; });
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenSubFloatOp // AtenSubFloatOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2381,6 +2390,18 @@ OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) {
[](double a, double b) -> double { return a / b; }); [](double a, double b) -> double { return a / b; });
} }
//===----------------------------------------------------------------------===//
// AtenAddFloatIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenAddFloatIntOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA() || !adaptor.getB()) {
return nullptr;
}
return atenBinaryFloatOperatorFoldHelper(
adaptor.getOperands(), [](double a, double b) { return a + b; });
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenPowIntFloatOp // AtenPowIntFloatOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2421,6 +2442,21 @@ OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
} }
//===----------------------------------------------------------------------===//
// AtenNegFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenNegFloatOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA()) {
return nullptr;
}
auto value = adaptor.getA().dyn_cast_or_null<FloatAttr>();
if (!value) {
return nullptr;
}
return getF64FloatAttr(getContext(), -value.getValue().convertToDouble());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenSqrtIntOp // AtenSqrtIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -633,11 +633,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::div.int : (int, int) -> (float)", has_folder=True) emit("aten::div.int : (int, int) -> (float)", has_folder=True)
emit("aten::neg.int : (int) -> (int)", has_folder=True) emit("aten::neg.int : (int) -> (int)", has_folder=True)
emit("aten::log.int : (int) -> (float)") emit("aten::log.int : (int) -> (float)")
emit("aten::add.float_int : (float, int) -> (float)") emit("aten::add.float_int : (float, int) -> (float)", has_folder=True)
emit("aten::sub.float : (float, float) -> (float)", has_folder=True) emit("aten::sub.float : (float, float) -> (float)", has_folder=True)
emit("aten::mul.float : (float, float) -> (float)") emit("aten::mul.float : (float, float) -> (float)", has_folder=True)
emit("aten::div.float : (float, float) -> (float)", has_folder=True) emit("aten::div.float : (float, float) -> (float)", has_folder=True)
emit("aten::neg.float : (float) -> (float)") emit("aten::neg.float : (float) -> (float)", has_folder=True)
emit("aten::eq.float : (float, float) -> (bool)", has_folder=True) emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
emit("aten::gt.float : (float, float) -> (bool)", has_folder=True) emit("aten::gt.float : (float, float) -> (bool)", has_folder=True)
emit("aten::ge.float : (float, float) -> (bool)", has_folder=True) emit("aten::ge.float : (float, float) -> (bool)", has_folder=True)

View File

@ -1036,6 +1036,16 @@ func.func @torch.aten.add.int() -> !torch.int {
return %ret : !torch.int return %ret : !torch.int
} }
// CHECK-LABEL: func.func @torch.aten.add.float_int() -> !torch.float {
// CHECK: %[[CST9:.*]] = torch.constant.float 9.000000e+00
// CHECK: return %[[CST9]] : !torch.float
func.func @torch.aten.add.float_int() -> !torch.float {
%cst4 = torch.constant.float 4.0
%cst5 = torch.constant.int 5
%ret = torch.aten.add.float_int %cst4, %cst5: !torch.float, !torch.int -> !torch.float
return %ret : !torch.float
}
// CHECK-LABEL: func.func @torch.aten.sub.int() -> !torch.int { // CHECK-LABEL: func.func @torch.aten.sub.int() -> !torch.int {
// CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: return %[[CST1]] : !torch.int // CHECK: return %[[CST1]] : !torch.int
@ -1056,6 +1066,25 @@ func.func @torch.aten.mul.int() -> !torch.int {
return %ret : !torch.int return %ret : !torch.int
} }
// CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float {
// CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01
// CHECK: return %[[CST30]] : !torch.float
func.func @torch.aten.mul.float() -> !torch.float {
%cst6 = torch.constant.float 6.0
%cst5 = torch.constant.float 5.0
%ret = torch.aten.mul.float %cst6, %cst5: !torch.float, !torch.float -> !torch.float
return %ret : !torch.float
}
// CHECK-LABEL: func.func @torch.aten.neg.float() -> !torch.float {
// CHECK: %[[CST_6:.*]] = torch.constant.float -6.000000e+00
// CHECK: return %[[CST_6]] : !torch.float
func.func @torch.aten.neg.float() -> !torch.float {
%cst6 = torch.constant.float 6.0
%ret = torch.aten.neg.float %cst6: !torch.float -> !torch.float
return %ret : !torch.float
}
// CHECK-LABEL: func.func @torch.aten.mul.int$with_zero() -> !torch.int { // CHECK-LABEL: func.func @torch.aten.mul.int$with_zero() -> !torch.int {
// CHECK: %[[CST0:.*]] = torch.constant.int 0 // CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: return %[[CST0]] : !torch.int // CHECK: return %[[CST0]] : !torch.int