Add folder for ToF64Op and FromF64Op (#1257)

pull/1260/head
武家伟 2022-08-22 09:49:39 +08:00 committed by GitHub
parent ba17a4d6c0
commit 99fb4c8637
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 0 deletions

View File

@ -154,6 +154,7 @@ def TorchConversion_ToF64Op : TorchConversion_Op<"to_f64", [
let assemblyFormat = [{
$operand attr-dict
}];
let hasFolder = 1;
}
def TorchConversion_FromF64Op : TorchConversion_Op<"from_f64", [
@ -172,6 +173,7 @@ def TorchConversion_FromF64Op : TorchConversion_Op<"from_f64", [
let assemblyFormat = [{
$operand attr-dict
}];
let hasFolder = 1;
}
def TorchConversion_I64ToGeneratorOp : TorchConversion_Op<"i64_to_generator", [

View File

@ -97,5 +97,31 @@ OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
}
}
//===----------------------------------------------------------------------===//
// ToF64Op
//===----------------------------------------------------------------------===//
OpFoldResult ToF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
if (attr) {
return attr;
} else {
return nullptr;
}
}
//===----------------------------------------------------------------------===//
// FromF64Op
//===----------------------------------------------------------------------===//
OpFoldResult FromF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
if (attr) {
return attr;
} else {
return nullptr;
}
}
#define GET_OP_CLASSES
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc"

View File

@ -37,3 +37,41 @@ func.func @torch_c.to_i64$from_i64() -> !torch.int {
%1 = torch_c.from_i64 %0
return %1 : !torch.int
}
// CHECK-LABEL: func.func @torch_c.from_f64() -> !torch.float {
// CHECK: %[[FLOAT5:.*]] = torch.constant.float 5.000000e+00
// CHECK: return %[[FLOAT5]] : !torch.float
func.func @torch_c.from_f64() -> !torch.float {
%c5_f64 = arith.constant 5.000000e+00 : f64
%0 = torch_c.from_f64 %c5_f64
return %0 : !torch.float
}
// CHECK-LABEL: func.func @torch_c.to_f64() -> f64 {
// CHECK: %[[C5_f64:.*]] = arith.constant 5.000000e+00 : f64
// CHECK: return %[[C5_f64]] : f64
func.func @torch_c.to_f64() -> f64 {
%float5 = torch.constant.float 5.000000e+00
%0 = torch_c.to_f64 %float5
return %0 : f64
}
// CHECK-LABEL: func.func @torch_c.from_f64$to_f64() -> f64 {
// CHECK: %[[C5_f64:.*]] = arith.constant 5.000000e+00 : f64
// CHECK: return %[[C5_f64]] : f64
func.func @torch_c.from_f64$to_f64() -> f64 {
%c5_f64 = arith.constant 5.000000e+00 : f64
%0 = torch_c.from_f64 %c5_f64
%1 = torch_c.to_f64 %0
return %1 : f64
}
// CHECK-LABEL: func.func @torch_c.to_f64$from_f64() -> !torch.float {
// CHECK: %[[FLOAT5:.*]] = torch.constant.float 5.000000e+00
// CHECK: return %[[FLOAT5]] : !torch.float
func.func @torch_c.to_f64$from_f64() -> !torch.float {
%float5 = torch.constant.float 5.000000e+00
%0 = torch_c.to_f64 %float5
%1 = torch_c.from_f64 %0
return %1 : !torch.float
}