diff --git a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionBase.td b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionBase.td index 785f29ae4..fadeb5c8c 100644 --- a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionBase.td +++ b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionBase.td @@ -25,6 +25,8 @@ def TorchConversion_Dialect : Dialect { tensor ops being converted linalg-on-tensors and `!torch.vtensor` being converted to the builtin `tensor` type. }]; + + let hasConstantMaterializer = 1; } #endif // TORCHCONVERSION_BASE diff --git a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td index c06a3dc97..32782186f 100644 --- a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td +++ b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td @@ -114,6 +114,7 @@ def TorchConversion_ToI64Op : TorchConversion_Op<"to_i64", [ let assemblyFormat = [{ $operand attr-dict }]; + let hasFolder = 1; } def TorchConversion_FromI64Op : TorchConversion_Op<"from_i64", [ @@ -132,6 +133,7 @@ def TorchConversion_FromI64Op : TorchConversion_Op<"from_i64", [ let assemblyFormat = [{ $operand attr-dict }]; + let hasFolder = 1; } def TorchConversion_ToF64Op : TorchConversion_Op<"to_f64", [ diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp index e18464f7f..9be768986 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp @@ -8,10 +8,13 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Transforms/InliningUtils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" @@ -50,3 +53,25 @@ void TorchConversionDialect::initialize() { >(); addInterfaces(); } + + +//===----------------------------------------------------------------------===// +// Constant materializer. +//===----------------------------------------------------------------------===// + +Operation *TorchConversionDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + if (auto integerType = type.dyn_cast()) + return builder.create(loc, value.cast()); + + if (auto floatType = type.dyn_cast()) + return builder.create(loc, value.cast()); + + if (type.isa()) { + return builder.create(loc, + value.cast()); + } + + return builder.create(loc, value, type); +} diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp index 7c7c5fb94..3dc36bf53 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp @@ -37,5 +37,23 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes( return success(); } +OpFoldResult FromI64Op::fold(llvm::ArrayRef operands) { + auto attr = operands[0].dyn_cast_or_null(); + if (attr) { + return attr; + } else { + return nullptr; + } +} + +OpFoldResult ToI64Op::fold(llvm::ArrayRef operands) { + auto attr = operands[0].dyn_cast_or_null(); + if (attr) { + return attr; + } else { + return nullptr; + } +} + #define GET_OP_CLASSES #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc" diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 822126df0..9550e2bba 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -79,6 +79,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( // Finish the type conversion from `torch` types to the types of the // linalg-on-tensors backend contract. pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); + pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass( TorchConversion::createFinalizingBackendTypeConversionPass()); @@ -108,6 +109,7 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline( // Finish the type conversion from `torch` types to the types of the // TOSA backend contract. pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); + pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass( TorchConversion::createFinalizingBackendTypeConversionPass()); diff --git a/test/Dialect/TorchConversion/canonicalize.mlir b/test/Dialect/TorchConversion/canonicalize.mlir new file mode 100644 index 000000000..bf6a9b9d3 --- /dev/null +++ b/test/Dialect/TorchConversion/canonicalize.mlir @@ -0,0 +1,39 @@ +// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s + +// CHECK-LABEL: func.func @torch_c.from_i64() -> !torch.int { +// CHECK: %[[INT5:.*]] = torch.constant.int 5 +// CHECK: return %[[INT5]] : !torch.int +func.func @torch_c.from_i64() -> !torch.int { + %c5_i64 = arith.constant 5 : i64 + %0 = torch_c.from_i64 %c5_i64 + return %0 : !torch.int +} + +// CHECK-LABEL: func.func @torch_c.to_i64() -> i64 { +// CHECK: %[[C5_I64:.*]] = arith.constant 5 : i64 +// CHECK: return %[[C5_I64]] : i64 +func.func @torch_c.to_i64() -> i64 { + %int5 = torch.constant.int 5 + %0 = torch_c.to_i64 %int5 + return %0 : i64 +} + +// CHECK-LABEL: func.func @torch_c.from_i64$to_i64() -> i64 { +// CHECK: %[[C5_I64:.*]] = arith.constant 5 : i64 +// CHECK: return %[[C5_I64]] : i64 +func.func @torch_c.from_i64$to_i64() -> i64 { + %c5_i64 = arith.constant 5 : i64 + %0 = torch_c.from_i64 %c5_i64 + %1 = torch_c.to_i64 %0 + return %1 : i64 +} + +// CHECK-LABEL: func.func @torch_c.to_i64$from_i64() -> !torch.int { +// CHECK: %[[INT5:.*]] = torch.constant.int 5 +// CHECK: return %[[INT5]] : !torch.int +func.func @torch_c.to_i64$from_i64() -> !torch.int { + %int5 = torch.constant.int 5 + %0 = torch_c.to_i64 %int5 + %1 = torch_c.from_i64 %0 + return %1 : !torch.int +}