mirror of https://github.com/llvm/torch-mlir
parent
82819350e1
commit
449cfb8375
|
@ -10554,6 +10554,7 @@ def Torch_AtenAddFloatIntOp : Torch_Op<"aten.add.float_int", [
|
|||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [
|
||||
|
@ -10603,6 +10604,7 @@ def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [
|
|||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenDivFloatOp : Torch_Op<"aten.div.float", [
|
||||
|
@ -10651,6 +10653,7 @@ def Torch_AtenNegFloatOp : Torch_Op<"aten.neg.float", [
|
|||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenEqFloatOp : Torch_Op<"aten.eq.float", [
|
||||
|
|
|
@ -2319,6 +2319,15 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenMulFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenMulFloatOp::fold(FoldAdaptor adaptor) {
|
||||
return atenBinaryFloatOperatorFoldHelper(
|
||||
adaptor.getOperands(), [](double a, double b) { return a * b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSubFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2381,6 +2390,18 @@ OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) {
|
|||
[](double a, double b) -> double { return a / b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenAddFloatIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenAddFloatIntOp::fold(FoldAdaptor adaptor) {
|
||||
if (!adaptor.getA() || !adaptor.getB()) {
|
||||
return nullptr;
|
||||
}
|
||||
return atenBinaryFloatOperatorFoldHelper(
|
||||
adaptor.getOperands(), [](double a, double b) { return a + b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenPowIntFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2421,6 +2442,21 @@ OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenNegFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenNegFloatOp::fold(FoldAdaptor adaptor) {
|
||||
if (!adaptor.getA()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto value = adaptor.getA().dyn_cast_or_null<FloatAttr>();
|
||||
if (!value) {
|
||||
return nullptr;
|
||||
}
|
||||
return getF64FloatAttr(getContext(), -value.getValue().convertToDouble());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSqrtIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -633,11 +633,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::div.int : (int, int) -> (float)", has_folder=True)
|
||||
emit("aten::neg.int : (int) -> (int)", has_folder=True)
|
||||
emit("aten::log.int : (int) -> (float)")
|
||||
emit("aten::add.float_int : (float, int) -> (float)")
|
||||
emit("aten::add.float_int : (float, int) -> (float)", has_folder=True)
|
||||
emit("aten::sub.float : (float, float) -> (float)", has_folder=True)
|
||||
emit("aten::mul.float : (float, float) -> (float)")
|
||||
emit("aten::mul.float : (float, float) -> (float)", has_folder=True)
|
||||
emit("aten::div.float : (float, float) -> (float)", has_folder=True)
|
||||
emit("aten::neg.float : (float) -> (float)")
|
||||
emit("aten::neg.float : (float) -> (float)", has_folder=True)
|
||||
emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
|
||||
emit("aten::gt.float : (float, float) -> (bool)", has_folder=True)
|
||||
emit("aten::ge.float : (float, float) -> (bool)", has_folder=True)
|
||||
|
|
|
@ -1036,6 +1036,16 @@ func.func @torch.aten.add.int() -> !torch.int {
|
|||
return %ret : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.add.float_int() -> !torch.float {
|
||||
// CHECK: %[[CST9:.*]] = torch.constant.float 9.000000e+00
|
||||
// CHECK: return %[[CST9]] : !torch.float
|
||||
func.func @torch.aten.add.float_int() -> !torch.float {
|
||||
%cst4 = torch.constant.float 4.0
|
||||
%cst5 = torch.constant.int 5
|
||||
%ret = torch.aten.add.float_int %cst4, %cst5: !torch.float, !torch.int -> !torch.float
|
||||
return %ret : !torch.float
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.sub.int() -> !torch.int {
|
||||
// CHECK: %[[CST1:.*]] = torch.constant.int 1
|
||||
// CHECK: return %[[CST1]] : !torch.int
|
||||
|
@ -1056,6 +1066,25 @@ func.func @torch.aten.mul.int() -> !torch.int {
|
|||
return %ret : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float {
|
||||
// CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01
|
||||
// CHECK: return %[[CST30]] : !torch.float
|
||||
func.func @torch.aten.mul.float() -> !torch.float {
|
||||
%cst6 = torch.constant.float 6.0
|
||||
%cst5 = torch.constant.float 5.0
|
||||
%ret = torch.aten.mul.float %cst6, %cst5: !torch.float, !torch.float -> !torch.float
|
||||
return %ret : !torch.float
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.neg.float() -> !torch.float {
|
||||
// CHECK: %[[CST_6:.*]] = torch.constant.float -6.000000e+00
|
||||
// CHECK: return %[[CST_6]] : !torch.float
|
||||
func.func @torch.aten.neg.float() -> !torch.float {
|
||||
%cst6 = torch.constant.float 6.0
|
||||
%ret = torch.aten.neg.float %cst6: !torch.float -> !torch.float
|
||||
return %ret : !torch.float
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.mul.int$with_zero() -> !torch.int {
|
||||
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
||||
// CHECK: return %[[CST0]] : !torch.int
|
||||
|
|
Loading…
Reference in New Issue