[MLIR][TORCH] Add folder for torch_c.from_i64 & torch_c.to_i64 (#933)

* [MLIR][TORCH] Add folder for torch_c.from_i64 & torch_c.to_i64

* add unit tests for each individual fold

* fix failure of NumelZeroRankModule & TestMultipleTensorAndPrimitiveTypesReturn
pull/951/head
Tanyo Kwok 2022-06-24 09:34:39 +08:00 committed by GitHub
parent 189afa82c5
commit 143a7bcb76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 88 additions and 0 deletions

View File

@ -25,6 +25,8 @@ def TorchConversion_Dialect : Dialect {
tensor ops being converted linalg-on-tensors and `!torch.vtensor` being tensor ops being converted linalg-on-tensors and `!torch.vtensor` being
converted to the builtin `tensor` type. converted to the builtin `tensor` type.
}]; }];
let hasConstantMaterializer = 1;
} }
#endif // TORCHCONVERSION_BASE #endif // TORCHCONVERSION_BASE

View File

@ -114,6 +114,7 @@ def TorchConversion_ToI64Op : TorchConversion_Op<"to_i64", [
let assemblyFormat = [{ let assemblyFormat = [{
$operand attr-dict $operand attr-dict
}]; }];
let hasFolder = 1;
} }
def TorchConversion_FromI64Op : TorchConversion_Op<"from_i64", [ def TorchConversion_FromI64Op : TorchConversion_Op<"from_i64", [
@ -132,6 +133,7 @@ def TorchConversion_FromI64Op : TorchConversion_Op<"from_i64", [
let assemblyFormat = [{ let assemblyFormat = [{
$operand attr-dict $operand attr-dict
}]; }];
let hasFolder = 1;
} }
def TorchConversion_ToF64Op : TorchConversion_Op<"to_f64", [ def TorchConversion_ToF64Op : TorchConversion_Op<"to_f64", [

View File

@ -8,10 +8,13 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/InliningUtils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/TypeSwitch.h"
@ -50,3 +53,25 @@ void TorchConversionDialect::initialize() {
>(); >();
addInterfaces<TorchConversionInlinerInterface>(); addInterfaces<TorchConversionInlinerInterface>();
} }
//===----------------------------------------------------------------------===//
// Constant materializer.
//===----------------------------------------------------------------------===//
Operation *TorchConversionDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto integerType = type.dyn_cast<Torch::IntType>())
return builder.create<Torch::ConstantIntOp>(loc, value.cast<IntegerAttr>());
if (auto floatType = type.dyn_cast<Torch::FloatType>())
return builder.create<Torch::ConstantFloatOp>(loc, value.cast<FloatAttr>());
if (type.isa<Torch::BoolType>()) {
return builder.create<Torch::ConstantBoolOp>(loc,
value.cast<IntegerAttr>());
}
return builder.create<arith::ConstantOp>(loc, value, type);
}

View File

@ -37,5 +37,23 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes(
return success(); return success();
} }
OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
if (attr) {
return attr;
} else {
return nullptr;
}
}
OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
if (attr) {
return attr;
} else {
return nullptr;
}
}
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc"

View File

@ -79,6 +79,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
// Finish the type conversion from `torch` types to the types of the // Finish the type conversion from `torch` types to the types of the
// linalg-on-tensors backend contract. // linalg-on-tensors backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>( pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass()); TorchConversion::createFinalizingBackendTypeConversionPass());
@ -108,6 +109,7 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
// Finish the type conversion from `torch` types to the types of the // Finish the type conversion from `torch` types to the types of the
// TOSA backend contract. // TOSA backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>( pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass()); TorchConversion::createFinalizingBackendTypeConversionPass());

View File

@ -0,0 +1,39 @@
// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s
// CHECK-LABEL: func.func @torch_c.from_i64() -> !torch.int {
// CHECK: %[[INT5:.*]] = torch.constant.int 5
// CHECK: return %[[INT5]] : !torch.int
func.func @torch_c.from_i64() -> !torch.int {
%c5_i64 = arith.constant 5 : i64
%0 = torch_c.from_i64 %c5_i64
return %0 : !torch.int
}
// CHECK-LABEL: func.func @torch_c.to_i64() -> i64 {
// CHECK: %[[C5_I64:.*]] = arith.constant 5 : i64
// CHECK: return %[[C5_I64]] : i64
func.func @torch_c.to_i64() -> i64 {
%int5 = torch.constant.int 5
%0 = torch_c.to_i64 %int5
return %0 : i64
}
// CHECK-LABEL: func.func @torch_c.from_i64$to_i64() -> i64 {
// CHECK: %[[C5_I64:.*]] = arith.constant 5 : i64
// CHECK: return %[[C5_I64]] : i64
func.func @torch_c.from_i64$to_i64() -> i64 {
%c5_i64 = arith.constant 5 : i64
%0 = torch_c.from_i64 %c5_i64
%1 = torch_c.to_i64 %0
return %1 : i64
}
// CHECK-LABEL: func.func @torch_c.to_i64$from_i64() -> !torch.int {
// CHECK: %[[INT5:.*]] = torch.constant.int 5
// CHECK: return %[[INT5]] : !torch.int
func.func @torch_c.to_i64$from_i64() -> !torch.int {
%int5 = torch.constant.int 5
%0 = torch_c.to_i64 %int5
%1 = torch_c.from_i64 %0
return %1 : !torch.int
}