mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add folder for torch_c.from_i64 & torch_c.to_i64 (#933)
* [MLIR][TORCH] Add folder for torch_c.from_i64 & torch_c.to_i64 * add unit tests for each individual fold * fix failure of NumelZeroRankModule & TestMultipleTensorAndPrimitiveTypesReturnpull/951/head
parent
189afa82c5
commit
143a7bcb76
|
@ -25,6 +25,8 @@ def TorchConversion_Dialect : Dialect {
|
||||||
tensor ops being converted linalg-on-tensors and `!torch.vtensor` being
|
tensor ops being converted linalg-on-tensors and `!torch.vtensor` being
|
||||||
converted to the builtin `tensor` type.
|
converted to the builtin `tensor` type.
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let hasConstantMaterializer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // TORCHCONVERSION_BASE
|
#endif // TORCHCONVERSION_BASE
|
||||||
|
|
|
@ -114,6 +114,7 @@ def TorchConversion_ToI64Op : TorchConversion_Op<"to_i64", [
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$operand attr-dict
|
$operand attr-dict
|
||||||
}];
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TorchConversion_FromI64Op : TorchConversion_Op<"from_i64", [
|
def TorchConversion_FromI64Op : TorchConversion_Op<"from_i64", [
|
||||||
|
@ -132,6 +133,7 @@ def TorchConversion_FromI64Op : TorchConversion_Op<"from_i64", [
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$operand attr-dict
|
$operand attr-dict
|
||||||
}];
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TorchConversion_ToF64Op : TorchConversion_Op<"to_f64", [
|
def TorchConversion_ToF64Op : TorchConversion_Op<"to_f64", [
|
||||||
|
|
|
@ -8,10 +8,13 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||||
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
#include "mlir/Transforms/InliningUtils.h"
|
#include "mlir/Transforms/InliningUtils.h"
|
||||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||||
#include "llvm/ADT/StringExtras.h"
|
#include "llvm/ADT/StringExtras.h"
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
|
||||||
|
@ -50,3 +53,25 @@ void TorchConversionDialect::initialize() {
|
||||||
>();
|
>();
|
||||||
addInterfaces<TorchConversionInlinerInterface>();
|
addInterfaces<TorchConversionInlinerInterface>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Constant materializer.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
Operation *TorchConversionDialect::materializeConstant(OpBuilder &builder,
|
||||||
|
Attribute value, Type type,
|
||||||
|
Location loc) {
|
||||||
|
if (auto integerType = type.dyn_cast<Torch::IntType>())
|
||||||
|
return builder.create<Torch::ConstantIntOp>(loc, value.cast<IntegerAttr>());
|
||||||
|
|
||||||
|
if (auto floatType = type.dyn_cast<Torch::FloatType>())
|
||||||
|
return builder.create<Torch::ConstantFloatOp>(loc, value.cast<FloatAttr>());
|
||||||
|
|
||||||
|
if (type.isa<Torch::BoolType>()) {
|
||||||
|
return builder.create<Torch::ConstantBoolOp>(loc,
|
||||||
|
value.cast<IntegerAttr>());
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.create<arith::ConstantOp>(loc, value, type);
|
||||||
|
}
|
||||||
|
|
|
@ -37,5 +37,23 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||||
|
auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
|
||||||
|
if (attr) {
|
||||||
|
return attr;
|
||||||
|
} else {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||||
|
auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
|
||||||
|
if (attr) {
|
||||||
|
return attr;
|
||||||
|
} else {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc"
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc"
|
||||||
|
|
|
@ -79,6 +79,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
||||||
// Finish the type conversion from `torch` types to the types of the
|
// Finish the type conversion from `torch` types to the types of the
|
||||||
// linalg-on-tensors backend contract.
|
// linalg-on-tensors backend contract.
|
||||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||||
|
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||||
pm.addNestedPass<func::FuncOp>(
|
pm.addNestedPass<func::FuncOp>(
|
||||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||||
|
|
||||||
|
@ -108,6 +109,7 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
|
||||||
// Finish the type conversion from `torch` types to the types of the
|
// Finish the type conversion from `torch` types to the types of the
|
||||||
// TOSA backend contract.
|
// TOSA backend contract.
|
||||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||||
|
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||||
pm.addNestedPass<func::FuncOp>(
|
pm.addNestedPass<func::FuncOp>(
|
||||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch_c.from_i64() -> !torch.int {
|
||||||
|
// CHECK: %[[INT5:.*]] = torch.constant.int 5
|
||||||
|
// CHECK: return %[[INT5]] : !torch.int
|
||||||
|
func.func @torch_c.from_i64() -> !torch.int {
|
||||||
|
%c5_i64 = arith.constant 5 : i64
|
||||||
|
%0 = torch_c.from_i64 %c5_i64
|
||||||
|
return %0 : !torch.int
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch_c.to_i64() -> i64 {
|
||||||
|
// CHECK: %[[C5_I64:.*]] = arith.constant 5 : i64
|
||||||
|
// CHECK: return %[[C5_I64]] : i64
|
||||||
|
func.func @torch_c.to_i64() -> i64 {
|
||||||
|
%int5 = torch.constant.int 5
|
||||||
|
%0 = torch_c.to_i64 %int5
|
||||||
|
return %0 : i64
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch_c.from_i64$to_i64() -> i64 {
|
||||||
|
// CHECK: %[[C5_I64:.*]] = arith.constant 5 : i64
|
||||||
|
// CHECK: return %[[C5_I64]] : i64
|
||||||
|
func.func @torch_c.from_i64$to_i64() -> i64 {
|
||||||
|
%c5_i64 = arith.constant 5 : i64
|
||||||
|
%0 = torch_c.from_i64 %c5_i64
|
||||||
|
%1 = torch_c.to_i64 %0
|
||||||
|
return %1 : i64
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch_c.to_i64$from_i64() -> !torch.int {
|
||||||
|
// CHECK: %[[INT5:.*]] = torch.constant.int 5
|
||||||
|
// CHECK: return %[[INT5]] : !torch.int
|
||||||
|
func.func @torch_c.to_i64$from_i64() -> !torch.int {
|
||||||
|
%int5 = torch.constant.int 5
|
||||||
|
%0 = torch_c.to_i64 %int5
|
||||||
|
%1 = torch_c.from_i64 %0
|
||||||
|
return %1 : !torch.int
|
||||||
|
}
|
Loading…
Reference in New Issue