diff --git a/CMakeLists.txt b/CMakeLists.txt index 1d5e64b4f..d6ada3cf4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -128,6 +128,10 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) "suffix = '${PYTHON_MODULE_SUFFIX}', " "extension = '${PYTHON_MODULE_EXTENSION}") + # Include the iree-dialects external project. + set(LLVM_EXTERNAL_PROJECTS "iree-dialects") + set(LLVM_EXTERNAL_IREE_DIALECTS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects") + # LLVM configuration. message(STATUS "*** ADDING LLVM ***") add_subdirectory( @@ -177,6 +181,8 @@ include_directories(${LLVM_INCLUDE_DIRS}) include_directories(${MLIR_INCLUDE_DIRS}) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/llvm/tools/iree-dialects/include) link_directories(${LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) set(NPCOMP_TABLEGEN_ARGS "") diff --git a/external/iree-dialects/lib/Dialect/IREE/IREEDialect.cpp b/external/iree-dialects/lib/Dialect/IREE/IREEDialect.cpp index fd6983457..ce82241c0 100644 --- a/external/iree-dialects/lib/Dialect/IREE/IREEDialect.cpp +++ b/external/iree-dialects/lib/Dialect/IREE/IREEDialect.cpp @@ -22,7 +22,7 @@ using namespace mlir::iree; void IREEDialect::initialize() { addTypes< #define GET_TYPEDEF_LIST -#include "iree-dialects/Dialect/IREE/IREEOps.cpp.inc" +#include "iree-dialects/Dialect/IREE/IREEOpsTypes.cpp.inc" >(); addOperations< #define GET_OP_LIST diff --git a/include/npcomp/Conversion/Passes.td b/include/npcomp/Conversion/Passes.td index e71da3fe3..f56a29171 100644 --- a/include/npcomp/Conversion/Passes.td +++ b/include/npcomp/Conversion/Passes.td @@ -104,6 +104,14 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> { let constructor = "mlir::NPCOMP::createConvertTorchToLinalgPass()"; } +def ConvertTorchToIREE : Pass<"convert-torch-to-iree", "FuncOp"> { + let summary = "Convert recognized Torch ops to IREE ops"; + let description = [{ + TODO + }]; + let constructor = "mlir::NPCOMP::createConvertTorchToIREEPass()"; +} + //===----------------------------------------------------------------------===// // Basicpy conversions //===----------------------------------------------------------------------===// diff --git a/include/npcomp/Conversion/TorchToIREE/TorchToIREE.h b/include/npcomp/Conversion/TorchToIREE/TorchToIREE.h new file mode 100644 index 000000000..da6ba2880 --- /dev/null +++ b/include/npcomp/Conversion/TorchToIREE/TorchToIREE.h @@ -0,0 +1,21 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NPCOMP_CONVERSION_TORCHTOIREE_TORCHTOIREE_H
+#define NPCOMP_CONVERSION_TORCHTOIREE_TORCHTOIREE_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+namespace NPCOMP {
+std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToIREEPass();
+}
+} // namespace mlir
+
+#endif // NPCOMP_CONVERSION_TORCHTOIREE_TORCHTOIREE_H
diff --git a/include/npcomp/Dialect/CMakeLists.txt b/include/npcomp/Dialect/CMakeLists.txt
index 2068c869e..e48cf7467 100644
--- a/include/npcomp/Dialect/CMakeLists.txt
+++ b/include/npcomp/Dialect/CMakeLists.txt
@@ -3,3 +3,4 @@ add_subdirectory(Numpy)
 add_subdirectory(Refback)
 add_subdirectory(Refbackrt)
 add_subdirectory(Torch)
+add_subdirectory(TorchConversion)
diff --git a/include/npcomp/Dialect/Torch/IR/TorchBase.td b/include/npcomp/Dialect/Torch/IR/TorchBase.td
index 2f07c960a..8ff9b4786 100644
--- a/include/npcomp/Dialect/Torch/IR/TorchBase.td
+++ b/include/npcomp/Dialect/Torch/IR/TorchBase.td
@@ -19,7 +19,18 @@ def Torch_Dialect : Dialect {
    This dialect maintains a fairly isomorphic representation with
    TorchScript.

-    TODO: Add more detail here.
+    This dialect also provides transforms that lower it to the
+    "Torch backend contract", which is an IR form that we present to
+    later conversions, such as conversion to the npcomp backend contract.
+    The Torch backend contract significantly simplifies the IR representation
+    and puts it in a form easier for later lowering to work on. Specifically:
+    - The TorchScript object graph has been flattened to a list of globals (see
+      the GlobalizeObjectGraph transformation).
+    - Most of the operations have been changed to operate on value-semantic
+      tensors (see MaximizeValueSemantics).
+    - The number of op variants has been reduced (see ReduceOpVariants).
+    - Tensor sizes have been analyzed and static ranks inferred where possible
+      and propagated throughout the program.
  }];

  let hasRegionArgAttrVerify = 1;
diff --git a/include/npcomp/Dialect/Torch/IR/TorchOps.h b/include/npcomp/Dialect/Torch/IR/TorchOps.h
index 5919b5958..e73d71f5e 100644
--- a/include/npcomp/Dialect/Torch/IR/TorchOps.h
+++ b/include/npcomp/Dialect/Torch/IR/TorchOps.h
@@ -87,14 +87,15 @@ struct torch_list_construct_op_binder {
      : bind_values(bvs) {}

  bool match(Operation *op) {
-    if (auto constantNums = dyn_cast<Torch::PrimListConstructOp>(op)) {
-      for (Value value : constantNums.elements()) {
-        int64_t num;
-        if (matchPattern(value, m_TorchConstantInt(&num)))
-          bind_values.push_back(num);
-        else
-          return false;
-      }
+    auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
+    if (!listConstruct)
+      return false;
+    for (Value value : listConstruct.elements()) {
+      int64_t num;
+      if (matchPattern(value, m_TorchConstantInt(&num)))
+        bind_values.push_back(num);
+      else
+        return false;
    }
    return true;
  }
diff --git a/include/npcomp/Dialect/Torch/IR/TorchOps.td b/include/npcomp/Dialect/Torch/IR/TorchOps.td
index 9273c0294..2bc936fc4 100644
--- a/include/npcomp/Dialect/Torch/IR/TorchOps.td
+++ b/include/npcomp/Dialect/Torch/IR/TorchOps.td
@@ -611,156 +611,6 @@ def Torch_ConstantBoolOp : Torch_Op<"constant.bool",
  let hasFolder = 1;
}

-//===----------------------------------------------------------------------===//
-// Conversions to builtin types.
-//===----------------------------------------------------------------------===// - -def Torch_ToBuiltinTensorOp : Torch_Op<"to_builtin_tensor", [ - DeclareOpInterfaceMethods - ]> { - let summary = "Convert a `!torch.vtensor` to a `tensor`"; - let description = [{ - This op only operates on ValueTensorType, to avoid conflating conversions - between value-semantic and non-value-semantic types. - }]; - let arguments = (ins - Torch_ValueTensorType:$operand - ); - let results = (outs - AnyTensor:$result - ); - let assemblyFormat = [{ - $operand attr-dict `:` type($operand) `->` type($result) - }]; -} - -def Torch_FromBuiltinTensorOp : Torch_Op<"from_builtin_tensor", [ - DeclareOpInterfaceMethods - ]> { - let summary = "Convert a `tensor` to a `!torch.vtensor`"; - let description = [{ - This op only operates on ValueTensorType, to avoid conflating conversions - between value-semantic and non-value-semantic types. - }]; - let arguments = (ins - AnyTensor:$operand - ); - let results = (outs - Torch_ValueTensorType:$result - ); - let assemblyFormat = [{ - $operand attr-dict `:` type($operand) `->` type($result) - }]; -} - -def Torch_ToI1Op : Torch_Op<"to_i1", [ - DeclareOpInterfaceMethods - ]> { - let summary = "Convert a `!torch.bool` to an `i1`"; - let description = [{ - This op is primarily useful as a materialization during dialect conversion. - }]; - let arguments = (ins - Torch_BoolType:$operand - ); - let results = (outs - I1:$result - ); - let assemblyFormat = [{ - $operand attr-dict - }]; -} - -def Torch_FromI1Op : Torch_Op<"from_i1", [ - DeclareOpInterfaceMethods - ]> { - let summary = "Convert an `i1` to a `!torch.bool`"; - let description = [{ - This op is primarily useful as a materialization during dialect conversion. - }]; - let arguments = (ins - I1:$operand - ); - let results = (outs - Torch_BoolType:$result - ); - let assemblyFormat = [{ - $operand attr-dict - }]; -} - -def Torch_ToI64Op : Torch_Op<"to_i64", [ - DeclareOpInterfaceMethods - ]> { - let summary = "Convert a `!torch.int` to an `i64`"; - let description = [{ - This op is primarily useful as a materialization during dialect conversion. - }]; - let arguments = (ins - Torch_IntType:$operand - ); - let results = (outs - I64:$result - ); - let assemblyFormat = [{ - $operand attr-dict - }]; -} - -def Torch_FromI64Op : Torch_Op<"from_i64", [ - DeclareOpInterfaceMethods - ]> { - let summary = "Convert an `i64` to a `!torch.int`"; - let description = [{ - This op is primarily useful as a materialization during dialect conversion. - }]; - let arguments = (ins - I64:$operand - ); - let results = (outs - Torch_IntType:$result - ); - let assemblyFormat = [{ - $operand attr-dict - }]; -} - -def Torch_ToF64Op : Torch_Op<"to_f64", [ - DeclareOpInterfaceMethods - ]> { - let summary = "Convert a `!torch.float` to an `f64`"; - let description = [{ - This op is primarily useful as a materialization during dialect conversion. - }]; - let arguments = (ins - Torch_FloatType:$operand - ); - let results = (outs - F64:$result - ); - let assemblyFormat = [{ - $operand attr-dict - }]; -} - -def Torch_FromF64Op : Torch_Op<"from_f64", [ - DeclareOpInterfaceMethods - ]> { - let summary = "Convert an `f64` to a `!torch.float`"; - let description = [{ - This op is primarily useful as a materialization during dialect conversion. 
- }]; - let arguments = (ins - F64:$operand - ); - let results = (outs - Torch_FloatType:$result - ); - let assemblyFormat = [{ - $operand attr-dict - }]; -} - //===----------------------------------------------------------------------===// // Additional ops used to model TorchScript's Graph's / Node's. //===----------------------------------------------------------------------===// diff --git a/include/npcomp/Dialect/Torch/Transforms/Passes.h b/include/npcomp/Dialect/Torch/Transforms/Passes.h index ea0558cd1..01205c463 100644 --- a/include/npcomp/Dialect/Torch/Transforms/Passes.h +++ b/include/npcomp/Dialect/Torch/Transforms/Passes.h @@ -31,16 +31,14 @@ struct TorchLoweringPipelineOptions }; /// Creates a pipeline that lowers the object graph IR that is produced by -/// TorchScript import into the form expected by npcomp-verify-backend-contract. -void createLowerObjectGraphPipeline( +/// TorchScript import into the form expected by torch-verify-backend-contract. +void createTorchScriptToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options); /// Creates a pipeline that lowers a flat list of funcs and global slots /// with the torch and aten dialects and mutable arrays and converts it to -/// the form required by npcomp-verify-backend-contract, in particular -/// lowering most arrays to ranked tensors of known dtype, lowering aten ops to -/// linalg, converting torch.prim.* ops to elementary math operations. -void createLowerToNpcompBackendPipeline( +/// the form required by torch-verify-backend-contract. +void createGlobalizedModuleToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options); std::unique_ptr> createAdjustCallingConventionsPass(); @@ -55,14 +53,6 @@ std::unique_ptr> createMaximizeValueSemanticsPass(); std::unique_ptr> createRefinePublicReturnPass(); -std::unique_ptr> -createVerifyInvariantsBeforeBackendLoweringPass(); - -std::unique_ptr> createFuncBackendTypeConversionPass(); - -std::unique_ptr> -createFinalizingBackendTypeConversionPass(); - } // namespace Torch /// Registers all Torch transformation passes. diff --git a/include/npcomp/Dialect/Torch/Transforms/Passes.td b/include/npcomp/Dialect/Torch/Transforms/Passes.td index 2665769bb..83d39a251 100644 --- a/include/npcomp/Dialect/Torch/Transforms/Passes.td +++ b/include/npcomp/Dialect/Torch/Transforms/Passes.td @@ -214,40 +214,4 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> { }]; } -def VerifyInvariantsBeforeBackendLowering - : Pass<"torch-verify-invariants-before-backend-lowering", "ModuleOp"> { - let summary = "Verify invariants required by backend lowering"; - let constructor = - "mlir::NPCOMP::Torch::createVerifyInvariantsBeforeBackendLoweringPass()"; - let description = [{ - This pass checks any invariants needed by the process of lowering the - `torch` dialect to the npcomp backend contract. - - The most important invariant is that all tensors should be ranked and have - a known dtype. It is useful to catch this early because it usually - represents a simple bug in RefineTypes, but can manifest as many different - kinds of obscure symptoms during lowering. - }]; -} - -def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "ModuleOp"> { - let summary = "Convert functions to operate on builtin tensors"; - let constructor = "mlir::NPCOMP::Torch::createFuncBackendTypeConversionPass()"; - let description = [{ - Partial type conversion pass analogous in scope to the upstream - `func-bufferize` pass. 
See details there.
-  }];
-}
-
-def FinalizingBackendTypeConversion
-    : Pass<"torch-finalizing-backend-type-conversion", "FuncOp"> {
-  let summary = "Finalizes a partial conversion to builtin tensors";
-  let constructor =
-      "mlir::NPCOMP::Torch::createFinalizingBackendTypeConversionPass()";
-  let description = [{
-    Analogous in scope to the upstream `finalizing-bufferize` pass.
-    See details there.
-  }];
-}
-
 #endif // NPCOMP_TORCH_PASSES
diff --git a/include/npcomp/Dialect/TorchConversion/CMakeLists.txt b/include/npcomp/Dialect/TorchConversion/CMakeLists.txt
new file mode 100644
index 000000000..9f57627c3
--- /dev/null
+++ b/include/npcomp/Dialect/TorchConversion/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/include/npcomp/Dialect/TorchConversion/IR/CMakeLists.txt b/include/npcomp/Dialect/TorchConversion/IR/CMakeLists.txt
new file mode 100644
index 000000000..096018f23
--- /dev/null
+++ b/include/npcomp/Dialect/TorchConversion/IR/CMakeLists.txt
@@ -0,0 +1,15 @@
+set(LLVM_TARGET_DEFINITIONS TorchConversionOps.td)
+mlir_tablegen(TorchConversionOps.h.inc -gen-op-decls)
+mlir_tablegen(TorchConversionOps.cpp.inc -gen-op-defs)
+mlir_tablegen(TorchConversionDialect.h.inc -gen-dialect-decls -dialect=torch_c)
+mlir_tablegen(TorchConversionDialect.cpp.inc -gen-dialect-defs -dialect=torch_c)
+add_public_tablegen_target(MLIRTorchConversionOpsIncGen)
+add_dependencies(mlir-headers MLIRTorchConversionOpsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TorchConversionTypes.td)
+mlir_tablegen(TorchConversionTypes.h.inc -gen-typedef-decls)
+mlir_tablegen(TorchConversionTypes.cpp.inc -gen-typedef-defs)
+add_public_tablegen_target(MLIRTorchConversionTypesIncGen)
+
+add_mlir_doc(TorchConversionDialect TorchConversionDialect TorchConversion/ -gen-dialect-doc)
+add_mlir_doc(TorchConversionOps TorchConversionOps TorchConversion/ -gen-op-doc)
diff --git a/include/npcomp/Dialect/TorchConversion/IR/TorchConversionBase.td b/include/npcomp/Dialect/TorchConversion/IR/TorchConversionBase.td
new file mode 100644
index 000000000..bdd2738e3
--- /dev/null
+++ b/include/npcomp/Dialect/TorchConversion/IR/TorchConversionBase.td
@@ -0,0 +1,29 @@
+//===-------------------------------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHCONVERSION_BASE
+#define TORCHCONVERSION_BASE
+
+include "mlir/IR/OpBase.td"
+
+def TorchConversion_Dialect : Dialect {
+  // `torch_conversion` is too verbose.
+  let name = "torch_c";
+  let cppNamespace = "::mlir::NPCOMP::TorchConversion";
+  let description = [{
+    This dialect contains ops and transforms for converting from the Torch
+    backend contract to the npcomp backend contract.
+
+    This mainly consists of converting ops and types from the `torch` dialect
+    to the mix of dialects of the npcomp backend contract, such as tensor
+    ops being converted to linalg-on-tensors, lists being converted to IREE lists,
+    and !torch.float being converted to `f64`.
+ }]; +} + +#endif // TORCHCONVERSION_BASE diff --git a/include/npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h b/include/npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h new file mode 100644 index 000000000..addda9447 --- /dev/null +++ b/include/npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h @@ -0,0 +1,16 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHDIALECT_H +#define NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHDIALECT_H + +#include "mlir/IR/Dialect.h" + +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h.inc" + +#endif // NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHDIALECT_H diff --git a/include/npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h b/include/npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h new file mode 100644 index 000000000..73afd2638 --- /dev/null +++ b/include/npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h @@ -0,0 +1,25 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHOPS_H +#define NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHOPS_H + +#include "iree-dialects/Dialect/IREE/IREEDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "npcomp/Dialect/Torch/IR/TorchTypes.h" +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h" + +#define GET_OP_CLASSES +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h.inc" + +#endif // NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHOPS_H diff --git a/include/npcomp/Dialect/TorchConversion/IR/TorchConversionOps.td b/include/npcomp/Dialect/TorchConversion/IR/TorchConversionOps.td new file mode 100644 index 000000000..ca7912dbb --- /dev/null +++ b/include/npcomp/Dialect/TorchConversion/IR/TorchConversionOps.td @@ -0,0 +1,207 @@ +//===-------------------------------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHCONVERSION_OPS +#define TORCHCONVERSION_OPS + +include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "npcomp/Dialect/TorchConversion/IR/TorchConversionBase.td" +include "npcomp/Dialect/Torch/IR/TorchTypes.td" +include "iree-dialects/Dialect/IREE/IREEDialect.td" + +class TorchConversion_Op traits = []> + : Op { +} + +//===----------------------------------------------------------------------===// +// Conversions to backend types. +//===----------------------------------------------------------------------===// + +def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Convert a `!torch.vtensor` to a `tensor`"; + let description = [{ + This op only operates on ValueTensorType, to avoid conflating conversions + between value-semantic and non-value-semantic types. + }]; + let arguments = (ins + Torch_ValueTensorType:$operand + ); + let results = (outs + AnyTensor:$result + ); + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `->` type($result) + }]; +} + +def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tensor", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Convert a `tensor` to a `!torch.vtensor`"; + let description = [{ + This op only operates on ValueTensorType, to avoid conflating conversions + between value-semantic and non-value-semantic types. + }]; + let arguments = (ins + AnyTensor:$operand + ); + let results = (outs + Torch_ValueTensorType:$result + ); + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `->` type($result) + }]; +} + +def TorchConversion_ToI1Op : TorchConversion_Op<"to_i1", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Convert a `!torch.bool` to an `i1`"; + let description = [{ + This op is primarily useful as a materialization during dialect conversion. + }]; + let arguments = (ins + Torch_BoolType:$operand + ); + let results = (outs + I1:$result + ); + let assemblyFormat = [{ + $operand attr-dict + }]; +} + +def TorchConversion_FromI1Op : TorchConversion_Op<"from_i1", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Convert an `i1` to a `!torch.bool`"; + let description = [{ + This op is primarily useful as a materialization during dialect conversion. + }]; + let arguments = (ins + I1:$operand + ); + let results = (outs + Torch_BoolType:$result + ); + let assemblyFormat = [{ + $operand attr-dict + }]; +} + +def TorchConversion_ToI64Op : TorchConversion_Op<"to_i64", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Convert a `!torch.int` to an `i64`"; + let description = [{ + This op is primarily useful as a materialization during dialect conversion. + }]; + let arguments = (ins + Torch_IntType:$operand + ); + let results = (outs + I64:$result + ); + let assemblyFormat = [{ + $operand attr-dict + }]; +} + +def TorchConversion_FromI64Op : TorchConversion_Op<"from_i64", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Convert an `i64` to a `!torch.int`"; + let description = [{ + This op is primarily useful as a materialization during dialect conversion. 
+ }]; + let arguments = (ins + I64:$operand + ); + let results = (outs + Torch_IntType:$result + ); + let assemblyFormat = [{ + $operand attr-dict + }]; +} + +def TorchConversion_ToF64Op : TorchConversion_Op<"to_f64", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Convert a `!torch.float` to an `f64`"; + let description = [{ + This op is primarily useful as a materialization during dialect conversion. + }]; + let arguments = (ins + Torch_FloatType:$operand + ); + let results = (outs + F64:$result + ); + let assemblyFormat = [{ + $operand attr-dict + }]; +} + +def TorchConversion_FromF64Op : TorchConversion_Op<"from_f64", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Convert an `f64` to a `!torch.float`"; + let description = [{ + This op is primarily useful as a materialization during dialect conversion. + }]; + let arguments = (ins + F64:$operand + ); + let results = (outs + Torch_FloatType:$result + ); + let assemblyFormat = [{ + $operand attr-dict + }]; +} + +// TODO: Verify the element types match. +def TorchConversion_ToIREEListOp : TorchConversion_Op<"to_iree_list", [ + ]> { + let summary = "Convert a `!torch.list` to a `!iree.list`"; + let description = [{ + }]; + let arguments = (ins + Torch_ListType:$operand + ); + let results = (outs + IREE_ListType:$result + ); + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `->` type($result) + }]; +} + +def TorchConversion_FromIREEListOp : TorchConversion_Op<"from_iree_list", [ + ]> { + let summary = "Convert a `!iree.list` to a `!torch.list`"; + let description = [{ + }]; + let arguments = (ins + IREE_ListType:$operand + ); + let results = (outs + Torch_ListType:$result + ); + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `->` type($result) + }]; +} + +#endif // TORCHCONVERSION_OPS diff --git a/include/npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h b/include/npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h similarity index 64% rename from include/npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h rename to include/npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h index 4ab976e67..f38c3e3f6 100644 --- a/include/npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h +++ b/include/npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h @@ -6,21 +6,26 @@ // //===----------------------------------------------------------------------===// -#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_BACKENDTYPECONVERSION_H -#define NPCOMP_DIALECT_TORCH_TRANSFORMS_BACKENDTYPECONVERSION_H +#ifndef NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_BACKENDTYPECONVERSION_H +#define NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_BACKENDTYPECONVERSION_H #include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace NPCOMP { -namespace Torch { +namespace TorchConversion { + +/// Get the dependent dialects which might be involved in a backend type +/// conversion. +void getBackendTypeConversionDependentDialects(DialectRegistry ®istry); + /// Set up the provided ConversionTarget and TypeConverter for converting /// from `torch` dialect types to the types along the npcomp backend boundary /// (which currently consist only of builtin types). 
void setupBackendTypeConversion(ConversionTarget &target, TypeConverter &typeConverter); -} // namespace Torch +} // namespace TorchConversion } // namespace NPCOMP } // namespace mlir -#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_BACKENDTYPECONVERSION_H +#endif // NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_BACKENDTYPECONVERSION_H diff --git a/include/npcomp/Dialect/TorchConversion/Transforms/CMakeLists.txt b/include/npcomp/Dialect/TorchConversion/Transforms/CMakeLists.txt new file mode 100644 index 000000000..19c61527c --- /dev/null +++ b/include/npcomp/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(NPCOMPTorchConversionPassIncGen) + +add_mlir_doc(Passes NPCOMPTorchConversionTransforms ./ -gen-pass-doc) diff --git a/include/npcomp/Dialect/TorchConversion/Transforms/Passes.h b/include/npcomp/Dialect/TorchConversion/Transforms/Passes.h new file mode 100644 index 000000000..5ec2a8ae0 --- /dev/null +++ b/include/npcomp/Dialect/TorchConversion/Transforms/Passes.h @@ -0,0 +1,44 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H +#define NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H + +#include "mlir/Pass/Pass.h" +#include "npcomp/Dialect/Torch/Transforms/Passes.h" + +#include + +namespace mlir { +namespace NPCOMP { +namespace TorchConversion { + +/// Creates a pipeline that lowers the object graph IR that is produced by +/// TorchScript import into the form expected by npcomp-verify-backend-contract. +void createTorchScriptToNpcompBackendPipeline( + OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options); + +std::unique_ptr> +createVerifyInvariantsBeforeBackendLoweringPass(); + +std::unique_ptr> createFuncBackendTypeConversionPass(); + +std::unique_ptr> +createFinalizingBackendTypeConversionPass(); + +std::unique_ptr> createTmpDeleteDeadIREEListsPass(); + +} // namespace TorchConversion + +/// Registers all Torch transformation passes. +void registerTorchConversionPasses(); + +} // namespace NPCOMP +} // namespace mlir + +#endif // NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H diff --git a/include/npcomp/Dialect/TorchConversion/Transforms/Passes.td b/include/npcomp/Dialect/TorchConversion/Transforms/Passes.td new file mode 100644 index 000000000..22c0a08af --- /dev/null +++ b/include/npcomp/Dialect/TorchConversion/Transforms/Passes.td @@ -0,0 +1,76 @@ +//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NPCOMP_TORCHCONVERSION_PASSES
+#define NPCOMP_TORCHCONVERSION_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def VerifyInvariantsBeforeBackendLowering
+    : Pass<"torch-verify-invariants-before-backend-lowering", "ModuleOp"> {
+  let summary = "Verify invariants required by backend lowering";
+  let constructor =
+      "mlir::NPCOMP::TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()";
+  let description = [{
+    This pass checks any invariants needed by the process of lowering the
+    `torch` dialect to the npcomp backend contract.
+
+    The most important invariant is that all tensors should be ranked and have
+    a known dtype. It is useful to catch this early because it usually
+    represents a simple bug in RefineTypes, but can manifest as many different
+    kinds of obscure symptoms during lowering.
+
+    TODO: This pass should probably be phrased as checking the
+    "torch backend contract" and moved to that dialect once we have a more
+    substantial definition around what that layer is from an
+    "allowlist" perspective.
+  }];
+}
+
+def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "ModuleOp"> {
+  let summary = "Convert functions to operate on builtin tensors";
+  let constructor = "mlir::NPCOMP::TorchConversion::createFuncBackendTypeConversionPass()";
+  let description = [{
+    Partial type conversion pass analogous in scope to the upstream
+    `func-bufferize` pass. See details there.
+  }];
+}
+
+def FinalizingBackendTypeConversion
+    : Pass<"torch-finalizing-backend-type-conversion", "FuncOp"> {
+  let summary = "Finalizes a partial conversion to builtin tensors";
+  let constructor =
+      "mlir::NPCOMP::TorchConversion::createFinalizingBackendTypeConversionPass()";
+  let description = [{
+    Analogous in scope to the upstream `finalizing-bufferize` pass.
+    See details there.
+  }];
+}
+
+
+def TmpDeleteDeadIREELists
+    : Pass<"torch-tmp-delete-dead-lists", "FuncOp"> {
+  let summary = "Delete dead !iree.list ops";
+  let constructor =
+      "mlir::NPCOMP::TorchConversion::createTmpDeleteDeadIREEListsPass()";
+  let description = [{
+    Runs a few patterns to delete dead !iree.list ops until IREE can support
+    running them. Currently, these will get materialized as part of conversions
+    for ops like AtenConv2dOp that have list operands, even though they are dead
+    (for those ops, we pattern match a specific case of static constant lists).
+    Currently, this will break execution of those tests because the IREE
+    side of these ops still doesn't work (nor is IREE able to delete them
+    itself).
+
+    TODO: Add support to IREE to run these ops E2E.
+    TODO: Remove this pass once IREE can run them e2e.
+ }]; +} + + +#endif // NPCOMP_TORCHCONVERSION_PASSES diff --git a/lib/Backend/Common/CMakeLists.txt b/lib/Backend/Common/CMakeLists.txt index c4bce4cc7..c4ab5be6e 100644 --- a/lib/Backend/Common/CMakeLists.txt +++ b/lib/Backend/Common/CMakeLists.txt @@ -17,6 +17,7 @@ add_npcomp_library(NPCOMPCommonBackend MLIRTensor MLIRStandard MLIRMath + IREEDialectsIREEDialect ) mlir_check_all_link_libraries(NPCOMPCommonBackend) diff --git a/lib/Backend/Common/VerifyBackendContract.cpp b/lib/Backend/Common/VerifyBackendContract.cpp index 895898beb..e751a090d 100644 --- a/lib/Backend/Common/VerifyBackendContract.cpp +++ b/lib/Backend/Common/VerifyBackendContract.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" +#include "iree-dialects/Dialect/IREE/IREEOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -32,6 +33,7 @@ class VerifyBackendContractPass return type; return nullptr; }); + converter.addConversion([](iree::ListType type) { return type; }); TypeConverter scalarConverter; for (TypeConverter *c : {&converter, &scalarConverter}) { c->addConversion([](FloatType type) { return type; }); @@ -59,8 +61,7 @@ class VerifyBackendContractPass // Tensor operations should go through linalg and the tensor dialect. target.addDynamicallyLegalDialect(opHasLegalTypes); target.addDynamicallyLegalDialect(opHasLegalTypes); - // DimOp is used to query tensor sizes. - target.addDynamicallyLegalOp(opHasLegalTypes); + target.addDynamicallyLegalDialect(opHasLegalTypes); // AssertOp is used to terminate the program for error guards. target.addLegalOp(); diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index de6ecd136..6d5cd3cbb 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -32,6 +32,7 @@ add_npcomp_library(NPCOMPInitAll NPCOMPRefBackend NPCOMPRefbackDialect NPCOMPTorchDialect + NPCOMPTorchConversionDialect NPCOMPRefbackrtDialect NPCOMPBasicpyDialect NPCOMPBasicpyPasses @@ -39,6 +40,7 @@ add_npcomp_library(NPCOMPInitAll NPCOMPNumpyDialect NPCOMPNumpyPasses NPCOMPTypingPasses + IREEDialectsIREEDialect # TODO: We shouldn't need npcomp_conversion_libs here, but we have # some dialect transform libraries accumulating into that property. 
diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index c0511e90b..83ed81724 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(TorchToIREE) add_subdirectory(TorchToLinalg) add_subdirectory(TorchToSCF) add_subdirectory(TorchToStd) diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 62acbf3c3..245416862 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -9,6 +9,7 @@ #include "npcomp/Conversion/Passes.h" #include "npcomp/Conversion/BasicpyToStd/Passes.h" +#include "npcomp/Conversion/TorchToIREE/TorchToIREE.h" #include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h" #include "npcomp/Conversion/TorchToSCF/TorchToSCF.h" #include "npcomp/Conversion/TorchToStd/TorchToStd.h" diff --git a/lib/Conversion/TorchToIREE/CMakeLists.txt b/lib/Conversion/TorchToIREE/CMakeLists.txt new file mode 100644 index 000000000..918a6a380 --- /dev/null +++ b/lib/Conversion/TorchToIREE/CMakeLists.txt @@ -0,0 +1,19 @@ +add_npcomp_conversion_library(NPCOMPTorchToIREE + TorchToIREE.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TorchToIREE + + DEPENDS + NPCOMPConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + NPCOMPTorchDialect + MLIRStandard + IREEDialectsIREEDialect +) diff --git a/lib/Conversion/TorchToIREE/TorchToIREE.cpp b/lib/Conversion/TorchToIREE/TorchToIREE.cpp new file mode 100644 index 000000000..ca8a4046e --- /dev/null +++ b/lib/Conversion/TorchToIREE/TorchToIREE.cpp @@ -0,0 +1,89 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "npcomp/Conversion/TorchToIREE/TorchToIREE.h" + +#include "../PassDetail.h" +#include "iree-dialects/Dialect/IREE/IREEOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Transforms/DialectConversion.h" +#include "npcomp/Dialect/Torch/IR/TorchOps.h" +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" + +using namespace mlir; +using namespace mlir::NPCOMP; +using namespace mlir::NPCOMP::Torch; + +//===----------------------------------------------------------------------===// +// The patterns +//===----------------------------------------------------------------------===// + +namespace { +class ConvertPrimListConstructOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(PrimListConstructOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto type = getTypeConverter()->convertType(op.getType()); + auto capacity = + rewriter.create(op.getLoc(), op->getNumOperands()); + auto ireeList = + rewriter.replaceOpWithNewOp(op, type, capacity); + for (int i = 0, e = operands.size(); i != e; ++i) { + auto index = rewriter.create(op.getLoc(), i); + rewriter.create(op.getLoc(), ireeList, index, + operands[i]); + } + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// The pass +//===----------------------------------------------------------------------===// + +namespace { +class ConvertTorchToIREE : public ConvertTorchToIREEBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalDialect(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversion(target, typeConverter); + + RewritePatternSet patterns(context); + + patterns.add(typeConverter, context); + target.addIllegalOp(); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr> +mlir::NPCOMP::createConvertTorchToIREEPass() { + return std::make_unique(); +} diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index c00912e88..265504142 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -16,7 +16,8 @@ #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" #include "npcomp/Dialect/Torch/IR/TorchOps.h" -#include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h" +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; using namespace mlir::NPCOMP; @@ -63,7 +64,7 @@ static LogicalResult verifyLinalgCompatibleTypes(Operation *op, // to end. 
Constant values can be be extracted directly and non constant // list values are not supported. // TODO: loose this constraint when properly support list type -static bool isConstantIntListMatching(Value &value, +static bool isConstantIntListMatching(Value value, llvm::SmallVectorImpl &expects) { llvm::SmallVector intValues; if (!matchPattern(value, m_TorchConstantIntList(intValues))) @@ -171,7 +172,6 @@ public: Location loc = op->getLoc(); MLIRContext *context = op->getContext(); AtenAdaptiveAvgPool2dOp::Adaptor adaptor(operands); - Value outputSize = adaptor.output_size(); Value input = adaptor.self(); /* in form of N*C*H*W */ RankedTensorType inputType = input.getType().cast(); Type elementType = inputType.getElementType(); @@ -183,7 +183,10 @@ public: return rewriter.notifyMatchFailure(op, "input should be rank 4"); SmallVector expects{1, 1}; - if (!isConstantIntListMatching(outputSize, expects)) + // Pattern match against the op's original operands, because otherwise we + // will get the lowered version of the operands which is harder to pattern + // match. + if (!isConstantIntListMatching(op.output_size(), expects)) return rewriter.notifyMatchFailure( op, "only support output_size with H and W both equal to constant 1"); @@ -269,9 +272,6 @@ public: AtenConv2dOp::Adaptor adaptor(operands); Value input = adaptor.input(); /* in form of N*C*H*W */ Value weight = adaptor.weight(); /* in form of F*C*H*W */ - Value padding = adaptor.padding(); - Value stride = adaptor.stride(); - Value dilation = adaptor.dilation(); Value groups = adaptor.groups(); Type elementType = @@ -291,18 +291,21 @@ public: Value weightH = getDimOp(rewriter, loc, weight, 2); Value weightW = getDimOp(rewriter, loc, weight, 3); + // Pattern match against the op's original operands, because otherwise we + // will get the lowered version of the operands which is harder to pattern + // match. llvm::SmallVector paddingInts; - if (!matchPattern(padding, m_TorchConstantIntList(paddingInts))) { + if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts))) { return rewriter.notifyMatchFailure( op, "only support constant padding values"); } llvm::SmallVector strideInts; - if (!matchPattern(stride, m_TorchConstantIntList(strideInts))) + if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts))) return rewriter.notifyMatchFailure(op, "only support constant int strides"); llvm::SmallVector dilationInts; - if (!matchPattern(dilation, m_TorchConstantIntList(dilationInts))) + if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); if (!op.bias().getType().isa()) @@ -905,30 +908,29 @@ public: Location loc = op->getLoc(); AtenMaxPool2dOp::Adaptor adaptor(operands); Value self = adaptor.self(); - Value kernelSize = adaptor.kernel_size(); - Value stride = adaptor.stride(); - Value padding = adaptor.padding(); - Value dilation = adaptor.dilation(); Value ceilMode = adaptor.ceil_mode(); Type elementType = self.getType().cast().getElementType(); if (!elementType.isa()) op.emitError("unimplemented: non-floating point type"); + // Pattern match against the op's original operands, because otherwise we + // will get the lowered version of the operands which is harder to pattern + // match. 
llvm::SmallVector strideInts; - if (!matchPattern(stride, m_TorchConstantIntList(strideInts))) + if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts))) return rewriter.notifyMatchFailure(op, "only support constant int strides"); llvm::SmallVector dilationInts; - if (!matchPattern(dilation, m_TorchConstantIntList(dilationInts))) + if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); llvm::SmallVector paddingInts; - if (!matchPattern(padding, m_TorchConstantIntList(paddingInts))) + if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts))) return rewriter.notifyMatchFailure(op, "only support constant int paddings"); llvm::SmallVector kernelSizeInts; - if (!matchPattern(kernelSize, m_TorchConstantIntList(kernelSizeInts))) + if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts))) return rewriter.notifyMatchFailure(op, "only support kernel size ints"); Value falseValue = rewriter.create( @@ -1113,6 +1115,7 @@ public: registry.insert(); registry.insert(); registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); } void runOnOperation() override { @@ -1123,7 +1126,7 @@ public: TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); - setupBackendTypeConversion(target, typeConverter); + TorchConversion::setupBackendTypeConversion(target, typeConverter); RewritePatternSet patterns(context); target.addIllegalOp(); diff --git a/lib/Conversion/TorchToSCF/CMakeLists.txt b/lib/Conversion/TorchToSCF/CMakeLists.txt index f4f12ba64..b41f958bb 100644 --- a/lib/Conversion/TorchToSCF/CMakeLists.txt +++ b/lib/Conversion/TorchToSCF/CMakeLists.txt @@ -16,4 +16,5 @@ add_npcomp_conversion_library(NPCOMPTorchToSCF MLIRSCF MLIRStandard NPCOMPTorchDialect + NPCOMPTorchConversionDialect ) diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp index 73069808a..30f7c38a5 100644 --- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp +++ b/lib/Conversion/TorchToSCF/TorchToSCF.cpp @@ -13,7 +13,8 @@ #include "mlir/Transforms/DialectConversion.h" #include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "npcomp/Dialect/Torch/IR/TorchOps.h" -#include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h" +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; using namespace mlir::NPCOMP; @@ -63,6 +64,7 @@ class ConvertTorchToSCF : public ConvertTorchToSCFBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); } void runOnOperation() override { @@ -72,7 +74,7 @@ public: TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); - setupBackendTypeConversion(target, typeConverter); + TorchConversion::setupBackendTypeConversion(target, typeConverter); RewritePatternSet patterns(context); target.addIllegalOp(); diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp index e81566dc3..dde93e044 100644 --- a/lib/Conversion/TorchToStd/TorchToStd.cpp +++ b/lib/Conversion/TorchToStd/TorchToStd.cpp @@ -14,7 +14,8 @@ #include "mlir/Transforms/DialectConversion.h" #include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "npcomp/Dialect/Torch/IR/TorchOps.h" -#include 
"npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h" +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; using namespace mlir::NPCOMP; @@ -93,6 +94,7 @@ class ConvertTorchToStd : public ConvertTorchToStdBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); } void runOnOperation() override { @@ -102,7 +104,7 @@ public: TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); - setupBackendTypeConversion(target, typeConverter); + TorchConversion::setupBackendTypeConversion(target, typeConverter); RewritePatternSet patterns(context); target.addIllegalOp(); diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt index 2068c869e..e48cf7467 100644 --- a/lib/Dialect/CMakeLists.txt +++ b/lib/Dialect/CMakeLists.txt @@ -3,3 +3,4 @@ add_subdirectory(Numpy) add_subdirectory(Refback) add_subdirectory(Refbackrt) add_subdirectory(Torch) +add_subdirectory(TorchConversion) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index b5136a04e..4eda85c2d 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -688,35 +688,6 @@ void CopyToValueTensorOp::getEffects( effects.emplace_back(MemoryEffects::Read::get(), getOperand()); } -//===----------------------------------------------------------------------===// -// ToBuiltinTensorOp -//===----------------------------------------------------------------------===// - -LogicalResult ToBuiltinTensorOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - auto resultType = - operands[0].getType().cast().toBuiltinTensor(); - if (!resultType) - return failure(); - inferredReturnTypes.push_back(resultType); - return success(); -} - -//===----------------------------------------------------------------------===// -// FromBuiltinTensorOp -//===----------------------------------------------------------------------===// - -LogicalResult FromBuiltinTensorOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - inferredReturnTypes.push_back( - ValueTensorType::getFromShaped(operands[0].getType().cast())); - return success(); -} - //===----------------------------------------------------------------------===// // ConstantNoneOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/CMakeLists.txt b/lib/Dialect/Torch/Transforms/CMakeLists.txt index c77542dbb..7559445a3 100644 --- a/lib/Dialect/Torch/Transforms/CMakeLists.txt +++ b/lib/Dialect/Torch/Transforms/CMakeLists.txt @@ -1,6 +1,5 @@ add_npcomp_conversion_library(NPCOMPTorchPasses AdjustCallingConventions.cpp - BackendTypeConversion.cpp Passes.cpp GlobalizeObjectGraph.cpp InlineGlobalSlots.cpp @@ -9,7 +8,6 @@ add_npcomp_conversion_library(NPCOMPTorchPasses ReduceOpVariants.cpp RefinePublicReturn.cpp RefineTypes.cpp - VerifyInvariantsBeforeBackendLowering.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms @@ -24,7 +22,5 @@ add_npcomp_conversion_library(NPCOMPTorchPasses MLIRIR MLIRPass NPCOMPTorchDialect - NPCOMPTorchToLinalg 
NPCOMPInterfaces - MLIRMemRefTransforms ) diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index aa97c9cd1..4a881201f 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -7,15 +7,8 @@ //===----------------------------------------------------------------------===// #include "npcomp/Dialect/Torch/Transforms/Passes.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" -#include "npcomp/Backend/Common/Passes.h" -#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h" -#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h" -#include "npcomp/Conversion/TorchToStd/TorchToStd.h" //===----------------------------------------------------------------------===// // Pass registration @@ -29,16 +22,16 @@ namespace { void mlir::NPCOMP::registerTorchPasses() { ::registerPasses(); mlir::PassPipelineRegistration( - "torchscript-to-npcomp-backend-pipeline", - "Pipeline lowering torch object graph to npcomp backend format.", - mlir::NPCOMP::Torch::createLowerObjectGraphPipeline); + "torchscript-to-torch-backend-pipeline", + "Pipeline lowering TorchScript object graph IR to Torch backend form.", + mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline); mlir::PassPipelineRegistration( - "torch-globalized-module-to-npcomp-backend-pipeline", - "Pipeline lowering to npcomp backend form.", - mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline); + "torch-globalized-module-to-torch-backend-pipeline", + "Pipeline lowering a globalized Torch program to Torch backend form.", + mlir::NPCOMP::Torch::createGlobalizedModuleToTorchBackendPipeline); } -void mlir::NPCOMP::Torch::createLowerObjectGraphPipeline( +void mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options) { // When we import TorchScript IR, we import their entire "compilation unit", // which can contain numerous functions unrelated to the current program, @@ -66,10 +59,10 @@ void mlir::NPCOMP::Torch::createLowerObjectGraphPipeline( // Incorporate user annotations and remove signature Python-isms. pm.addPass(createAdjustCallingConventionsPass()); - createLowerToNpcompBackendPipeline(pm, options); + createGlobalizedModuleToTorchBackendPipeline(pm, options); } -void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline( +void mlir::NPCOMP::Torch::createGlobalizedModuleToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options) { // General considerations: As a matter of bring-up, we are simultaneously // building out the frontend pipeline and also co-developing the backend @@ -140,39 +133,5 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline( pm.addNestedPass(createCanonicalizerPass()); } - //===--------------------------------------------------------------------===// - // Lowering ops and the !torch.vtensor type to the npcomp backend contract. - //===--------------------------------------------------------------------===// - - // Check some invariants to catch errors in a clear way. - pm.addPass(Torch::createVerifyInvariantsBeforeBackendLoweringPass()); - - // Lower to linalg + guards which is the input to codegen backends. - // We do this first as it tends to involve pattern-matching against constants, - // (e.g. 
dimensions which must be constant in a ranked programming model) - // and those constants get somewhat obscured by TorchToStd. - pm.addNestedPass(createConvertTorchToLinalgPass()); - - pm.addNestedPass(createConvertTorchToStdPass()); - pm.addNestedPass(createConvertTorchToSCFPass()); - pm.addNestedPass(createStdExpandOpsPass()); - - if (options.optimize) { - // Clean up any non-canonical code introduced in our linalg lowering. - pm.addNestedPass(createCanonicalizerPass()); - // Resolve `dim` ops on tensors (which currently live in the `memref` - // dialect for some reason -- we don't have memrefs at this level). - pm.addNestedPass(memref::createResolveShapedTypeResultDimsPass()); - // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); - } - - // Finish the type conversion from !torch.vtensor to the builtin tensor type. - pm.addPass(createFuncBackendTypeConversionPass()); - pm.addNestedPass(createFinalizingBackendTypeConversionPass()); - - // Verify that we have lowered to the form that backends expect. - // This fails compilation (signalPassFailure) if the IR is not in the - // correct form. - pm.addPass(CommonBackend::createVerifyBackendContractPass()); + // TODO: VerifyTorchBackendContractPass. } diff --git a/lib/Dialect/TorchConversion/CMakeLists.txt b/lib/Dialect/TorchConversion/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/lib/Dialect/TorchConversion/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/lib/Dialect/TorchConversion/IR/CMakeLists.txt b/lib/Dialect/TorchConversion/IR/CMakeLists.txt new file mode 100644 index 000000000..062b121b0 --- /dev/null +++ b/lib/Dialect/TorchConversion/IR/CMakeLists.txt @@ -0,0 +1,18 @@ +add_npcomp_dialect_library(NPCOMPTorchConversionDialect + TorchConversionDialect.cpp + TorchConversionOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/TorchConversion + + DEPENDS + MLIRTorchConversionOpsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRSupport + MLIRSideEffectInterfaces +) diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp new file mode 100644 index 000000000..83138cfee --- /dev/null +++ b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/Transforms/InliningUtils.h" +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::NPCOMP; +using namespace mlir::NPCOMP::TorchConversion; + +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// Dialect Interfaces +//===----------------------------------------------------------------------===// + +namespace { +struct TorchConversionInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + BlockAndValueMapping &valueMapping) const final { + return true; + } + bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, + BlockAndValueMapping &) const final { + return true; + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Dialect initialize method. +//===----------------------------------------------------------------------===// + +void TorchConversionDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc" + >(); + addInterfaces(); +} diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp new file mode 100644 index 000000000..83dc0ae58 --- /dev/null +++ b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "npcomp/Dialect/Torch/IR/TorchTypes.h" +#include "llvm/ADT/StringMap.h" + +using namespace mlir; +using namespace mlir::NPCOMP; +using namespace mlir::NPCOMP::TorchConversion; + +//===----------------------------------------------------------------------===// +// ToBuiltinTensorOp +//===----------------------------------------------------------------------===// + +LogicalResult ToBuiltinTensorOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + auto resultType = + operands[0].getType().cast().toBuiltinTensor(); + if (!resultType) + return failure(); + inferredReturnTypes.push_back(resultType); + return success(); +} + +//===----------------------------------------------------------------------===// +// FromBuiltinTensorOp +//===----------------------------------------------------------------------===// + +LogicalResult FromBuiltinTensorOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(Torch::ValueTensorType::getFromShaped( + operands[0].getType().cast())); + return success(); +} + +#define GET_OP_CLASSES +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc" diff --git a/lib/Dialect/Torch/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp similarity index 76% rename from lib/Dialect/Torch/Transforms/BackendTypeConversion.cpp rename to lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index ded951e1c..877eeb6bb 100644 --- a/lib/Dialect/Torch/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -8,24 +8,36 @@ #include "PassDetail.h" +#include "iree-dialects/Dialect/IREE/IREEDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/DialectConversion.h" -#include "npcomp/Dialect/Torch/IR/TorchOps.h" -#include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h" -#include "npcomp/Dialect/Torch/Transforms/Passes.h" +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h" using namespace mlir; using namespace mlir::NPCOMP; -using namespace mlir::NPCOMP::Torch; +using namespace mlir::NPCOMP::TorchConversion; + +void mlir::NPCOMP::TorchConversion::getBackendTypeConversionDependentDialects( + DialectRegistry ®istry) { + registry.insert(); + registry.insert(); +} + +//===----------------------------------------------------------------------===// +// Type conversion setup. 
+//===----------------------------------------------------------------------===// static void setupValueTensorToBuiltinTensorConversion(ConversionTarget &target, TypeConverter &typeConverter) { - target.addLegalOp(); + target.addLegalOp(); typeConverter.addConversion( [](Torch::ValueTensorType type) -> Optional { return type.toBuiltinTensor(); @@ -34,10 +46,11 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); + assert(inputs[0].getType().isa()); return builder.create(loc, inputs[0]); }); - auto sourceMaterialization = [](OpBuilder &builder, ValueTensorType type, + auto sourceMaterialization = [](OpBuilder &builder, + Torch::ValueTensorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); assert(inputs[0].getType().isa()); @@ -49,7 +62,7 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target, static void setupTorchBoolToI1Conversion(ConversionTarget &target, TypeConverter &typeConverter) { - target.addLegalOp(); + target.addLegalOp(); typeConverter.addConversion([](Torch::BoolType type) -> Optional { return IntegerType::get(type.getContext(), 1); }); @@ -75,7 +88,7 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target, static void setupTorchIntToI64Conversion(ConversionTarget &target, TypeConverter &typeConverter) { - target.addLegalOp(); + target.addLegalOp(); typeConverter.addConversion([](Torch::IntType type) -> Optional { return IntegerType::get(type.getContext(), 64); }); @@ -101,7 +114,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target, static void setupTorchFloatToF64Conversion(ConversionTarget &target, TypeConverter &typeConverter) { - target.addLegalOp(); + target.addLegalOp(); typeConverter.addConversion([](Torch::FloatType type) -> Optional { return Float64Type::get(type.getContext()); }); @@ -122,12 +135,38 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target, typeConverter.addArgumentMaterialization(sourceMaterialization); } -void mlir::NPCOMP::Torch::setupBackendTypeConversion( +static void setupTorchListToIREEListConversion(ConversionTarget &target, + TypeConverter &typeConverter) { + target.addLegalOp(); + typeConverter.addConversion([&](Torch::ListType type) -> Optional { + return iree::ListType::get( + type.getContext(), typeConverter.convertType(type.getContainedType())); + }); + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, iree::ListType type, ValueRange inputs, + Location loc) -> Optional { + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, type, inputs[0]).getResult(); + }); + auto sourceMaterialization = [](OpBuilder &builder, Torch::ListType type, + ValueRange inputs, Location loc) -> Value { + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, type, inputs[0]); + }; + typeConverter.addSourceMaterialization(sourceMaterialization); + typeConverter.addArgumentMaterialization(sourceMaterialization); +} + +void mlir::NPCOMP::TorchConversion::setupBackendTypeConversion( ConversionTarget &target, TypeConverter &typeConverter) { setupValueTensorToBuiltinTensorConversion(target, typeConverter); setupTorchBoolToI1Conversion(target, typeConverter); setupTorchIntToI64Conversion(target, typeConverter); setupTorchFloatToF64Conversion(target, typeConverter); + setupTorchListToIREEListConversion(target, typeConverter); } 
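+
+// Illustrative sketch (the types below are chosen for illustration; see
+// test/Dialect/TorchConversion/func-backend-type-conversion.mlir and
+// finalizing-backend-type-conversion.mlir for the authoritative examples):
+// with the conversions above installed, FuncBackendTypeConversionPass rewrites
+// function boundaries to use only builtin/IREE types and inserts torch_c
+// materializations to bridge back to the torch types inside the body, roughly:
+//
+//   func @identity(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+//     %0 = torch_c.from_builtin_tensor %arg0 : tensor<?xf32> -> !torch.vtensor<[?],f32>
+//     %1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[?],f32> -> tensor<?xf32>
+//     return %1 : tensor<?xf32>
+//   }
+//
+// FinalizingBackendTypeConversionPass then folds away back-to-back
+// materialization pairs like the one above, so no torch_c ops remain.
+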
//===----------------------------------------------------------------------===// @@ -139,6 +178,9 @@ struct FuncBackendTypeConversionPass : public FuncBackendTypeConversionBase { using FuncBackendTypeConversionBase< FuncBackendTypeConversionPass>::FuncBackendTypeConversionBase; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } void runOnOperation() override { auto module = getOperation(); auto *context = &getContext(); @@ -147,7 +189,7 @@ struct FuncBackendTypeConversionPass RewritePatternSet patterns(context); ConversionTarget target(*context); typeConverter.addConversion([](Type type) { return type; }); - setupBackendTypeConversion(target, typeConverter); + TorchConversion::setupBackendTypeConversion(target, typeConverter); populateFuncOpTypeConversionPattern(patterns, typeConverter); target.addDynamicallyLegalOp([&](FuncOp op) { @@ -176,7 +218,7 @@ struct FuncBackendTypeConversionPass } // namespace std::unique_ptr> -mlir::NPCOMP::Torch::createFuncBackendTypeConversionPass() { +mlir::NPCOMP::TorchConversion::createFuncBackendTypeConversionPass() { return std::make_unique(); } @@ -234,13 +276,13 @@ struct FinalizingBackendTypeConversionPass ConversionTarget target(*context); typeConverter.addConversion([](Type type) { return type; }); - setupBackendTypeConversion(target, typeConverter); + TorchConversion::setupBackendTypeConversion(target, typeConverter); // Mark materializations as illegal in this pass (since we are finalizing) // and add patterns that eliminate them. setupFinalization(target, patterns, - typeConverter); + FromI64Op, ToI64Op, FromF64Op, ToF64Op, FromIREEListOp, + ToIREEListOp>(target, patterns, typeConverter); // If all result types are legal, and all block arguments are legal, then // all types in the program are legal. @@ -259,6 +301,6 @@ struct FinalizingBackendTypeConversionPass } // namespace std::unique_ptr> -mlir::NPCOMP::Torch::createFinalizingBackendTypeConversionPass() { +mlir::NPCOMP::TorchConversion::createFinalizingBackendTypeConversionPass() { return std::make_unique(); } diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt new file mode 100644 index 000000000..b4311cdd5 --- /dev/null +++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -0,0 +1,28 @@ +add_npcomp_conversion_library(NPCOMPTorchConversionPasses + BackendTypeConversion.cpp + Passes.cpp + TmpDeleteDeadIREELists.cpp + VerifyInvariantsBeforeBackendLowering.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/TorchConversion/Transforms + + DEPENDS + NPCOMPTorchConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + NPCOMPTorchConversionDialect + NPCOMPTorchDialect + NPCOMPTorchPasses + NPCOMPTorchToIREE + NPCOMPTorchToLinalg + NPCOMPTorchToStd + NPCOMPTorchToSCF + NPCOMPInterfaces + MLIRMemRefTransforms +) diff --git a/lib/Dialect/TorchConversion/Transforms/PassDetail.h b/lib/Dialect/TorchConversion/Transforms/PassDetail.h new file mode 100644 index 000000000..e9196e589 --- /dev/null +++ b/lib/Dialect/TorchConversion/Transforms/PassDetail.h @@ -0,0 +1,25 @@ +//===- PassDetail.h - Pass details ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H +#define NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace NPCOMP { +namespace TorchConversion { + +#define GEN_PASS_CLASSES +#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h.inc" + +} // namespace TorchConversion +} // namespace NPCOMP +} // end namespace mlir + +#endif // NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp new file mode 100644 index 000000000..f43ece1d2 --- /dev/null +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -0,0 +1,97 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/StandardOps/Transforms/Passes.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "npcomp/Backend/Common/Passes.h" +#include "npcomp/Conversion/TorchToIREE/TorchToIREE.h" +#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h" +#include "npcomp/Conversion/TorchToStd/TorchToStd.h" +#include "npcomp/Dialect/Torch/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::NPCOMP; + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { +#define GEN_PASS_REGISTRATION +#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h.inc" +} // end namespace + +void mlir::NPCOMP::registerTorchConversionPasses() { + ::registerPasses(); + mlir::PassPipelineRegistration( + "torchscript-to-npcomp-backend-pipeline", + "Pipeline lowering torch object graph to npcomp backend format.", + mlir::NPCOMP::TorchConversion::createTorchScriptToNpcompBackendPipeline); +} + +void mlir::NPCOMP::TorchConversion::createTorchScriptToNpcompBackendPipeline( + OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) { + + // Conversion to the npcomp backend contract starts from the Torch backend + // contract. + Torch::createTorchScriptToTorchBackendPipeline(pm, options); + + // Check some invariants to catch errors in a clear way. + pm.addPass( + TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()); + + // Lower to linalg + guards which is the input to codegen backends. + // We do this first as it tends to involve pattern-matching against constants, + // (e.g. dimensions which must be constant in a ranked programming model) + // and those constants get somewhat obscured by TorchToStd. 
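+ // For example, lowerings that need a static list of sizes can match
+ // torch.constant.int operands directly; after TorchToStd those constants
+ // appear as `constant ... : i64` wrapped in torch_c.from_i64 (see
+ // test/Conversion/TorchToStd/basic.mlir), a form such patterns would no
+ // longer recognize.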
+ pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass()); + pm.addNestedPass<FuncOp>(createConvertTorchToStdPass()); + pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass()); + // Lists and other concepts that don't exist upstream go through the IREE + // dialect, which we treat as a reasonably well-designed interim placeholder + // for the set of ops that we think makes sense in the npcomp backend + // contract. We expect to co-evolve this dialect with npcomp needs, as a lot + // of what we are doing here in npcomp is breaking new ground w.r.t. + // expressiveness and program generality for tensor compilers. + // + // We lower lists last because the lowered form is much harder to reason about + // than the original form. + pm.addNestedPass<FuncOp>(createConvertTorchToIREEPass()); + pm.addNestedPass<FuncOp>(createStdExpandOpsPass()); + + if (options.optimize) { + // Clean up any non-canonical code introduced above. + pm.addNestedPass<FuncOp>(createCanonicalizerPass()); + // Resolve `dim` ops on tensors (which currently live in the `memref` + // dialect for some reason -- we don't have memrefs at this level). + pm.addNestedPass<FuncOp>(memref::createResolveShapedTypeResultDimsPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass<FuncOp>(createCSEPass()); + } + + // Finish the type conversion from `torch` types to the types of the npcomp + // backend contract. + pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); + pm.addNestedPass<FuncOp>( + TorchConversion::createFinalizingBackendTypeConversionPass()); + + // Temporarily delete dead list ops, since IREE cannot yet run them e2e. + // TODO: Add support to IREE to run these ops end-to-end, then remove + // this pass. + pm.addNestedPass<FuncOp>(TorchConversion::createTmpDeleteDeadIREEListsPass()); + + // Verify that we have lowered to the form that npcomp backends expect. + // This fails compilation (signalPassFailure) if the IR is not in the + // correct form. + pm.addPass(CommonBackend::createVerifyBackendContractPass()); +} diff --git a/lib/Dialect/TorchConversion/Transforms/TmpDeleteDeadIREELists.cpp b/lib/Dialect/TorchConversion/Transforms/TmpDeleteDeadIREELists.cpp new file mode 100644 index 000000000..9c3bf3985 --- /dev/null +++ b/lib/Dialect/TorchConversion/Transforms/TmpDeleteDeadIREELists.cpp @@ -0,0 +1,56 @@ +//===- TmpDeleteDeadIREELists.cpp --------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "iree-dialects/Dialect/IREE/IREEOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "npcomp/Dialect/Torch/IR/TorchOps.h" +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::NPCOMP; +using namespace mlir::NPCOMP::TorchConversion; + +namespace { +class TmpDeleteDeadIREEListsPass + : public TmpDeleteDeadIREEListsBase<TmpDeleteDeadIREEListsPass> { + void runOnOperation() override { + SmallVector<Operation *> toErase; + // Delete lists that are only set (but not read from). + // Such lists are created by our lowering of torch.prim.ListConstruct. + // Until IREE can run such ops e2e (or delete them itself), we need to + // do this cleanup. + // TODO: Add support to IREE to run these ops E2E.
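+ // The IR shape being cleaned up looks roughly like this (the f64 element
+ // type is just an example; see test/Conversion/TorchToIREE/basic.mlir):
+ //   %list = iree.list.create %c2 : !iree.list<f64>
+ //   iree.list.set %list[%c0], %v0 : !iree.list<f64>, f64
+ //   iree.list.set %list[%c1], %v1 : !iree.list<f64>, f64
+ // If the list has no other users, the create and all of its sets are erased.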
+ getOperation().walk([&](iree::ListCreateOp op) { + SmallVector deadOps; + deadOps.push_back(op); + for (auto &use : op.getResult().getUses()) { + if (isa(use.getOwner())) { + deadOps.push_back(use.getOwner()); + } else { + // We can't analyze the list op if it is used by something else. + return; + } + } + llvm::append_range(toErase, deadOps); + }); + for (auto *op : toErase) { + op->dropAllDefinedValueUses(); + op->erase(); + } + } +}; +} // namespace + +std::unique_ptr> +mlir::NPCOMP::TorchConversion::createTmpDeleteDeadIREEListsPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/Torch/Transforms/VerifyInvariantsBeforeBackendLowering.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyInvariantsBeforeBackendLowering.cpp similarity index 87% rename from lib/Dialect/Torch/Transforms/VerifyInvariantsBeforeBackendLowering.cpp rename to lib/Dialect/TorchConversion/Transforms/VerifyInvariantsBeforeBackendLowering.cpp index 6f1daf802..e2a3ff1ea 100644 --- a/lib/Dialect/Torch/Transforms/VerifyInvariantsBeforeBackendLowering.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyInvariantsBeforeBackendLowering.cpp @@ -11,17 +11,18 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "npcomp/Dialect/Torch/IR/TorchOps.h" -#include "npcomp/Dialect/Torch/Transforms/Passes.h" +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h" using namespace mlir; using namespace mlir::NPCOMP; -using namespace mlir::NPCOMP::Torch; +using namespace mlir::NPCOMP::TorchConversion; static LogicalResult checkValueInvariants(Operation *errorReportOp, Value v) { // TODO: Make this an allowlist instead of a denylist. // TODO: Make this stricter. auto type = v.getType(); - if (auto valueTensorType = type.dyn_cast()) { + if (auto valueTensorType = type.dyn_cast()) { if (!valueTensorType.hasDtype() || !valueTensorType.hasSizes()) return errorReportOp->emitError() .append("unsupported by backend lowering: tensor with unknown rank " @@ -77,7 +78,7 @@ class VerifyInvariantsBeforeBackendLoweringPass } // namespace -std::unique_ptr> -mlir::NPCOMP::Torch::createVerifyInvariantsBeforeBackendLoweringPass() { +std::unique_ptr> mlir::NPCOMP::TorchConversion:: + createVerifyInvariantsBeforeBackendLoweringPass() { return std::make_unique(); } diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index 9033a962f..c383d56f8 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -8,6 +8,7 @@ #include "npcomp/InitAll.h" +#include "iree-dialects/Dialect/IREE/IREEDialect.h" #include "mlir/IR/Dialect.h" #include "npcomp/Backend/Common/Passes.h" #include "npcomp/Backend/IREE/Passes.h" @@ -20,6 +21,8 @@ #include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h" #include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "npcomp/Dialect/Torch/Transforms/Passes.h" +#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h" #include "npcomp/RefBackend/RefBackend.h" #include "npcomp/Typing/Transforms/Passes.h" @@ -29,7 +32,9 @@ void mlir::NPCOMP::registerAllDialects(mlir::DialectRegistry ®istry) { Numpy::NumpyDialect, refbackrt::RefbackrtDialect, refback::RefbackDialect, - mlir::NPCOMP::Torch::TorchDialect>(); + mlir::NPCOMP::Torch::TorchDialect, + mlir::NPCOMP::TorchConversion::TorchConversionDialect, + iree::IREEDialect>(); // clang-format on } @@ -39,6 +44,7 @@ void mlir::NPCOMP::registerAllPasses() { mlir::NPCOMP::registerBasicpyPasses(); 
mlir::NPCOMP::registerNumpyPasses(); mlir::NPCOMP::registerTorchPasses(); + mlir::NPCOMP::registerTorchConversionPasses(); mlir::NPCOMP::registerTypingPasses(); mlir::NPCOMP::IREEBackend::registerIREEBackendPasses(); mlir::NPCOMP::CommonBackend::registerCommonBackendPasses(); diff --git a/test/Conversion/TorchToIREE/basic.mlir b/test/Conversion/TorchToIREE/basic.mlir new file mode 100644 index 000000000..b51c44fff --- /dev/null +++ b/test/Conversion/TorchToIREE/basic.mlir @@ -0,0 +1,19 @@ + +// RUN: npcomp-opt <%s -convert-torch-to-iree -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: builtin.func @forward( +// CHECK-SAME: %[[ARG_TORCH:.*]]: !torch.float) -> !torch.list { +// CHECK: %[[ARG:.*]] = torch_c.to_f64 %[[ARG_TORCH]] +// CHECK: %[[ALSO_ARG:.*]] = torch_c.to_f64 %[[ARG_TORCH]] +// CHECK: %[[C2:.*]] = constant 2 : index +// CHECK: %[[LIST:.*]] = iree.list.create %[[C2]] : !iree.list +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: iree.list.set %[[LIST]][%[[C0]]], %[[ARG]] : !iree.list, f64 +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: iree.list.set %[[LIST]][%[[C1]]], %[[ALSO_ARG]] : !iree.list, f64 +// CHECK: %[[LIST_TORCH:.*]] = torch_c.from_iree_list %[[LIST]] : !iree.list -> !torch.list +// CHECK: return %[[LIST_TORCH]] : !torch.list +builtin.func @forward(%arg0: !torch.float) -> !torch.list { + %0 = torch.prim.ListConstruct %arg0, %arg0 : (!torch.float, !torch.float) -> !torch.list + return %0 : !torch.list +} diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index c99ce2b96..413862067 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -3,8 +3,8 @@ // CHECK-LABEL: func @torch.aten.mm$basic( // CHECK-SAME: %[[LHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[RHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> { -// CHECK: %[[LHS:.*]] = torch.to_builtin_tensor %[[LHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RHS:.*]] = torch.to_builtin_tensor %[[RHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[LHS:.*]] = torch_c.to_builtin_tensor %[[LHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[RHS:.*]] = torch_c.to_builtin_tensor %[[RHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[LHS_DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : tensor // CHECK: %[[C1:.*]] = constant 1 : index @@ -20,7 +20,7 @@ // CHECK: %[[ZEROFILL:.*]] = linalg.fill(%[[CF0]], %[[INIT_TENSOR]]) : f32, tensor -> tensor // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor, tensor) outs(%[[ZEROFILL]] : tensor) -> tensor // CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor to tensor -// CHECK: %[[RESULT_VTENSOR:.*]] = torch.from_builtin_tensor %[[CASTED]] : tensor -> !torch.vtensor<[?,2],f32> +// CHECK: %[[RESULT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor -> !torch.vtensor<[?,2],f32> // CHECK: return %[[RESULT_VTENSOR]] : !torch.vtensor<[?,2],f32> func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> { %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32> diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir index 4f34bede4..81fd75eff 100644 --- a/test/Conversion/TorchToLinalg/elementwise.mlir +++ 
b/test/Conversion/TorchToLinalg/elementwise.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func @elementwise$unary( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { -// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor // CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor [] : tensor // CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor) outs(%[[INIT_TENSOR]] : tensor) { // CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32): @@ -11,7 +11,7 @@ // CHECK: linalg.yield %[[TANH]] : f32 // CHECK: } -> tensor // CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor to tensor -// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[CASTED]] : tensor -> !torch.vtensor<[],f32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor -> !torch.vtensor<[],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[],f32> // CHECK: } func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { @@ -22,8 +22,8 @@ func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> // CHECK-LABEL: func @elementwise$binary( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[BUILTIN_ARG0:.*]] = torch.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[BUILTIN_ARG1:.*]] = torch.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?],f32> -> tensor +// CHECK: %[[BUILTIN_ARG0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[BUILTIN_ARG1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?],f32> -> tensor // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[ARG0_DIM0:.*]] = tensor.dim %[[BUILTIN_ARG0]], %[[C0]] : tensor // CHECK: %[[C1:.*]] = constant 1 : index @@ -39,7 +39,7 @@ func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> // CHECK: linalg.yield %[[MUL]] : f32 // CHECK: } -> tensor // CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor to tensor -// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[CASTED]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> func @elementwise$binary(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32> @@ -61,7 +61,7 @@ func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vten // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> { // CHECK: %[[C1:.*]] = torch.constant.int 1 -// CHECK: %[[BUILTIN_C1:.*]] = torch.to_i64 %[[C1]] +// CHECK: %[[BUILTIN_C1:.*]] = torch_c.to_i64 %[[C1]] // CHECK: linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>] // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32): // CHECK: %[[ALPHA:.*]] = sitofp %[[BUILTIN_C1]] : i64 to f32 diff --git a/test/Conversion/TorchToLinalg/flatten.mlir 
b/test/Conversion/TorchToLinalg/flatten.mlir index f98499e80..3e7d13d18 100644 --- a/test/Conversion/TorchToLinalg/flatten.mlir +++ b/test/Conversion/TorchToLinalg/flatten.mlir @@ -4,10 +4,10 @@ // CHECK-LABEL: func @torch.aten.flatten.using_ints$basic( // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> { -// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32> +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32> // CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32> // CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32> -// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32> func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> { @@ -21,10 +21,10 @@ func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5], // CHECK-LABEL: func @torch.aten.flatten.using_ints$basic_negative( // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> { -// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32> +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32> // CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32> // CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32> -// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32> func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> { @@ -38,10 +38,10 @@ func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2, // CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_front( // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32> +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32> // CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32> // CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<18x2xf32> to tensor -// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor 
%[[DYNAMIC]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> { @@ -55,10 +55,10 @@ func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2 // CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_back( // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> { -// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32> +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32> // CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32> // CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x12xf32> to tensor -// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor -> !torch.vtensor<[?,12],f32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor -> !torch.vtensor<[?,12],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,12],f32> func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> { @@ -72,9 +72,9 @@ func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2] // CHECK-LABEL: func @torch.aten.flatten.using_ints$rank0( // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { -// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor // CHECK: %[[COLLAPSED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor into tensor<1xf32> -// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[COLLAPSED]] : tensor<1xf32> -> !torch.vtensor<[1],f32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32> func @torch.aten.flatten.using_ints$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir new file mode 100644 index 000000000..5fb72cc2f --- /dev/null +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -0,0 +1,27 @@ +// RUN: npcomp-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: builtin.func @forward +builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %int6 = torch.constant.int 6 + %int7 = torch.constant.int 7 + %int8 = torch.constant.int 8 + %false = torch.constant.bool false + // CHECK: %[[PADDED:.*]] = linalg.pad_tensor %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6] + // CHECK: %[[NEUTRAL:.*]] = constant -1.401300e-45 : f32 + // CHECK: %[[OUT:.*]] = linalg.fill(%[[NEUTRAL]], %{{.*}}) : f32, tensor -> tensor + // CHECK: %[[C1:.*]] = constant 1 : index + // CHECK: %[[C2:.*]] = constant 2 : index + // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[C1]], %[[C2]]] : tensor + // CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, 
strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor, tensor) outs(%[[OUT]] : tensor) -> tensor + %kernel_size = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list + %stride = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list + %padding = torch.prim.ListConstruct %int5, %int6 : (!torch.int, !torch.int) -> !torch.list + %dilation = torch.prim.ListConstruct %int7, %int8 : (!torch.int, !torch.int) -> !torch.list + %4 = torch.aten.max_pool2d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %4 : !torch.vtensor<[?,?,?,?],f32> +} diff --git a/test/Conversion/TorchToLinalg/unsqueeze.mlir b/test/Conversion/TorchToLinalg/unsqueeze.mlir index 4b0863b9d..3a43aad4f 100644 --- a/test/Conversion/TorchToLinalg/unsqueeze.mlir +++ b/test/Conversion/TorchToLinalg/unsqueeze.mlir @@ -5,9 +5,9 @@ // CHECK-LABEL: func @torch.aten.unsqueeze$basic( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { -// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor // CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor into tensor<1xf32> -// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32> func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { %int0 = torch.constant.int 0 @@ -17,9 +17,9 @@ func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtenso // CHECK-LABEL: func @torch.aten.unsqueeze$basic_negative( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { -// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor // CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor into tensor<1xf32> -// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32> func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { %int-1 = torch.constant.int -1 @@ -29,9 +29,9 @@ func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !tor // CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_front( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> { -// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> // CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32> -// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[EXPANDED]] : 
tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,3,4],f32> func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> { %int0 = torch.constant.int 0 @@ -41,9 +41,9 @@ func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>) // CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_back( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> { -// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> // CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32> -// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x4x1xf32> -> !torch.vtensor<[2,3,4,1],f32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x4x1xf32> -> !torch.vtensor<[2,3,4,1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4,1],f32> func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> { %int-1 = torch.constant.int -1 @@ -53,9 +53,9 @@ func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>) // CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_middle( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> { -// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> // CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32> -// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x1x4xf32> -> !torch.vtensor<[2,3,1,4],f32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x1x4xf32> -> !torch.vtensor<[2,3,1,4],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,1,4],f32> func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> { %int2 = torch.constant.int 2 diff --git a/test/Conversion/TorchToSCF/basic.mlir b/test/Conversion/TorchToSCF/basic.mlir index e482253d6..91d13c76d 100644 --- a/test/Conversion/TorchToSCF/basic.mlir +++ b/test/Conversion/TorchToSCF/basic.mlir @@ -4,15 +4,15 @@ // CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int { // CHECK: %[[VAL_1:.*]] = torch.constant.int 2 // CHECK: %[[VAL_2:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_3:.*]] = torch.to_i1 %[[VAL_0]] +// CHECK: %[[VAL_3:.*]] = torch_c.to_i1 %[[VAL_0]] // CHECK: %[[VAL_4:.*]] = scf.if %[[VAL_3]] -> (i64) { -// CHECK: %[[VAL_5:.*]] = torch.to_i64 %[[VAL_1]] +// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_1]] // CHECK: scf.yield %[[VAL_5]] : i64 // CHECK: } else { -// CHECK: %[[VAL_6:.*]] = torch.to_i64 %[[VAL_2]] +// CHECK: %[[VAL_6:.*]] = torch_c.to_i64 %[[VAL_2]] // CHECK: scf.yield %[[VAL_6]] : i64 // CHECK: } -// CHECK: %[[VAL_7:.*]] = torch.from_i64 %[[VAL_8:.*]] +// CHECK: %[[VAL_7:.*]] = torch_c.from_i64 %[[VAL_8:.*]] // CHECK: 
return %[[VAL_7]] : !torch.int func @torch.prim.if(%arg0: !torch.bool) -> !torch.int { %int2 = torch.constant.int 2 @@ -31,22 +31,22 @@ func @torch.prim.if(%arg0: !torch.bool) -> !torch.int { // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 // CHECK: %[[VAL_3:.*]] = torch.constant.int 3 // CHECK: %[[VAL_4:.*]] = torch.constant.int 4 -// CHECK: %[[VAL_5:.*]] = torch.to_i1 %[[VAL_0]] +// CHECK: %[[VAL_5:.*]] = torch_c.to_i1 %[[VAL_0]] // CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_5]] -> (i64) { -// CHECK: %[[VAL_7:.*]] = torch.to_i1 %[[VAL_1]] +// CHECK: %[[VAL_7:.*]] = torch_c.to_i1 %[[VAL_1]] // CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (i64) { -// CHECK: %[[VAL_9:.*]] = torch.to_i64 %[[VAL_2]] +// CHECK: %[[VAL_9:.*]] = torch_c.to_i64 %[[VAL_2]] // CHECK: scf.yield %[[VAL_9]] : i64 // CHECK: } else { -// CHECK: %[[VAL_10:.*]] = torch.to_i64 %[[VAL_3]] +// CHECK: %[[VAL_10:.*]] = torch_c.to_i64 %[[VAL_3]] // CHECK: scf.yield %[[VAL_10]] : i64 // CHECK: } // CHECK: scf.yield %[[VAL_11:.*]] : i64 // CHECK: } else { -// CHECK: %[[VAL_12:.*]] = torch.to_i64 %[[VAL_4]] +// CHECK: %[[VAL_12:.*]] = torch_c.to_i64 %[[VAL_4]] // CHECK: scf.yield %[[VAL_12]] : i64 // CHECK: } -// CHECK: %[[VAL_13:.*]] = torch.from_i64 %[[VAL_14:.*]] +// CHECK: %[[VAL_13:.*]] = torch_c.from_i64 %[[VAL_14:.*]] // CHECK: return %[[VAL_13]] : !torch.int func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int { %int2 = torch.constant.int 2 diff --git a/test/Conversion/TorchToStd/basic.mlir b/test/Conversion/TorchToStd/basic.mlir index c46a50eb1..438d42029 100644 --- a/test/Conversion/TorchToStd/basic.mlir +++ b/test/Conversion/TorchToStd/basic.mlir @@ -3,10 +3,10 @@ // CHECK-LABEL: func @torch.aten.dim( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.int { -// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<*,f32> -> tensor<*xf32> +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<*,f32> -> tensor<*xf32> // CHECK: %[[RANK:.*]] = rank %[[BUILTIN_TENSOR]] : tensor<*xf32> // CHECK: %[[RANK_I64:.*]] = index_cast %[[RANK]] : index to i64 -// CHECK: %[[RANK_TORCH_INT:.*]] = torch.from_i64 %[[RANK_I64]] +// CHECK: %[[RANK_TORCH_INT:.*]] = torch_c.from_i64 %[[RANK_I64]] // CHECK: return %[[RANK_TORCH_INT]] : !torch.int func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int { %0 = torch.aten.dim %arg0 : !torch.vtensor<*,f32> -> !torch.int @@ -16,10 +16,10 @@ func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int { // CHECK-LABEL: func @torch.aten.ne.int( // CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { -// CHECK: %[[LHS_I64:.*]] = torch.to_i64 %[[LHS]] -// CHECK: %[[RHS_I64:.*]] = torch.to_i64 %[[RHS]] +// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] // CHECK: %[[CMP:.*]] = cmpi ne, %[[LHS_I64]], %[[RHS_I64]] : i64 -// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch.from_i1 %[[CMP]] +// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool { %0 = torch.aten.ne.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool @@ -29,10 +29,10 @@ func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool { // CHECK-LABEL: func @torch.aten.gt.int( // CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { -// CHECK: %[[LHS_I64:.*]] = 
torch.to_i64 %[[LHS]] -// CHECK: %[[RHS_I64:.*]] = torch.to_i64 %[[RHS]] +// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] // CHECK: %[[CMP:.*]] = cmpi sgt, %[[LHS_I64]], %[[RHS_I64]] : i64 -// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch.from_i1 %[[CMP]] +// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool { %0 = torch.aten.gt.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool @@ -41,7 +41,7 @@ func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool { // CHECK-LABEL: func @torch.vtensor.literal() -> !torch.vtensor<[],f32> { // CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor -// CHECK: %[[VTENSOR:.*]] = torch.from_builtin_tensor %[[CST]] : tensor -> !torch.vtensor<[],f32> +// CHECK: %[[VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor -> !torch.vtensor<[],f32> // CHECK: return %[[VTENSOR]] : !torch.vtensor<[],f32> func @torch.vtensor.literal() -> !torch.vtensor<[],f32> { %0 = torch.vtensor.literal(dense<0.0> : tensor) : !torch.vtensor<[],f32> @@ -50,7 +50,7 @@ func @torch.vtensor.literal() -> !torch.vtensor<[],f32> { // CHECK-LABEL: func @torch.constant.bool() -> !torch.bool { // CHECK: %[[CST:.*]] = constant true -// CHECK: %[[BOOL:.*]] = torch.from_i1 %[[CST]] +// CHECK: %[[BOOL:.*]] = torch_c.from_i1 %[[CST]] // CHECK: return %[[BOOL]] : !torch.bool func @torch.constant.bool() -> !torch.bool { %true = torch.constant.bool true @@ -59,7 +59,7 @@ func @torch.constant.bool() -> !torch.bool { // CHECK-LABEL: func @torch.constant.float() -> !torch.float { // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f64 -// CHECK: %[[FLOAT:.*]] = torch.from_f64 %[[CST]] +// CHECK: %[[FLOAT:.*]] = torch_c.from_f64 %[[CST]] // CHECK: return %[[FLOAT]] : !torch.float func @torch.constant.float() -> !torch.float { %float = torch.constant.float 1.000000e+00 @@ -68,7 +68,7 @@ func @torch.constant.float() -> !torch.float { // CHECK-LABEL: func @torch.constant.int() -> !torch.int { // CHECK: %[[CST:.*]] = constant 1 : i64 -// CHECK: %[[INT:.*]] = torch.from_i64 %[[CST]] +// CHECK: %[[INT:.*]] = torch_c.from_i64 %[[CST]] // CHECK: return %[[INT]] : !torch.int func @torch.constant.int() -> !torch.int { %int1 = torch.constant.int 1 diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir index e718e84e4..493acd6b9 100644 --- a/test/Dialect/Torch/ops.mlir +++ b/test/Dialect/Torch/ops.mlir @@ -13,19 +13,6 @@ func @torch.linear_params.create(%arg0: !torch.tensor, %arg1: !torch.tensor) -> return %with_bias, %without_bias : !torch.LinearParams, !torch.LinearParams } -// CHECK-LABEL: func @builtin_tensor_interop( -func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xsi8>, %arg2: !torch.vtensor<*,f32>, %arg3: !torch.vtensor<[3,?],si8>) { - // CHECK: torch.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32> - %0 = torch.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32> - // CHECK: torch.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8> - %1 = torch.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8> - // CHECK: torch.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32> - %2 = torch.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32> - // CHECK: torch.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xsi8> - %3 = torch.to_builtin_tensor %arg3 : 
!torch.vtensor<[3,?],si8> -> tensor<3x?xsi8> - return -} - // CHECK: @tensor.default() -> !torch.tensor func private @tensor.default() -> !torch.tensor // CHECK: @tensor.default_explicit() -> !torch.tensor{{$}} diff --git a/test/Dialect/Torch/finalizing-backend-type-conversion.mlir b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir similarity index 66% rename from test/Dialect/Torch/finalizing-backend-type-conversion.mlir rename to test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir index abf5d2469..02f22f4bd 100644 --- a/test/Dialect/Torch/finalizing-backend-type-conversion.mlir +++ b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir @@ -7,8 +7,8 @@ // CHECK-SAME: %[[ARG:.*]]: tensor) -> tensor { // CHECK: return %[[ARG]] : tensor func @eliminate_materializations(%arg0: tensor) -> tensor { - %0 = torch.from_builtin_tensor %arg0 : tensor -> !torch.vtensor<[],f32> - %1 = torch.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor + %0 = torch_c.from_builtin_tensor %arg0 : tensor -> !torch.vtensor<[],f32> + %1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor return %1 : tensor } @@ -19,8 +19,8 @@ func @eliminate_materializations(%arg0: tensor) -> tensor { // CHECK-SAME: %[[ARG:.*]]: i1) -> i1 { // CHECK: return %[[ARG]] : i1 func @eliminate_materializations$torch.bool(%arg0: i1) -> i1 { - %0 = torch.from_i1 %arg0 - %1 = torch.to_i1 %0 + %0 = torch_c.from_i1 %arg0 + %1 = torch_c.to_i1 %0 return %1 : i1 } @@ -28,8 +28,8 @@ func @eliminate_materializations$torch.bool(%arg0: i1) -> i1 { // CHECK-SAME: %[[ARG:.*]]: i64) -> i64 { // CHECK: return %[[ARG]] : i64 func @eliminate_materializations$torch.int(%arg0: i64) -> i64 { - %0 = torch.from_i64 %arg0 - %1 = torch.to_i64 %0 + %0 = torch_c.from_i64 %arg0 + %1 = torch_c.to_i64 %0 return %1 : i64 } @@ -37,24 +37,33 @@ func @eliminate_materializations$torch.int(%arg0: i64) -> i64 { // CHECK-SAME: %[[ARG:.*]]: f64) -> f64 { // CHECK: return %[[ARG]] : f64 func @eliminate_materializations$torch.float(%arg0: f64) -> f64 { - %0 = torch.from_f64 %arg0 - %1 = torch.to_f64 %0 + %0 = torch_c.from_f64 %arg0 + %1 = torch_c.to_f64 %0 return %1 : f64 } +// CHECK-LABEL: func @eliminate_materializations$torch.list( +// CHECK-SAME: %[[ARG:.*]]: !iree.list) -> !iree.list { +// CHECK: return %[[ARG]] : !iree.list +func @eliminate_materializations$torch.list(%arg0: !iree.list) -> !iree.list { + %0 = torch_c.from_iree_list %arg0 : !iree.list -> !torch.list + %1 = torch_c.to_iree_list %0 : !torch.list -> !iree.list + return %1 : !iree.list +} + // ----- func @unable_to_convert_lone_buffer_cast() -> tensor { // expected-error @+1 {{failed to legalize operation 'test.source'}} %0 = "test.source"() : () -> !torch.vtensor<[],f32> - %1 = torch.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor + %1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor return %1 : tensor } // ----- func @unable_to_convert_lone_tensor_load(%arg0: tensor) { - %0 = torch.from_builtin_tensor %arg0 : tensor -> !torch.vtensor<[],f32> + %0 = torch_c.from_builtin_tensor %arg0 : tensor -> !torch.vtensor<[],f32> // expected-error @+1 {{failed to legalize operation 'test.sink'}} "test.sink"(%0) : (!torch.vtensor<[],f32>) -> () return diff --git a/test/Dialect/Torch/func-backend-type-conversion.mlir b/test/Dialect/TorchConversion/func-backend-type-conversion.mlir similarity index 75% rename from test/Dialect/Torch/func-backend-type-conversion.mlir rename to 
test/Dialect/TorchConversion/func-backend-type-conversion.mlir index 85c0a95da..5f036ba15 100644 --- a/test/Dialect/Torch/func-backend-type-conversion.mlir +++ b/test/Dialect/TorchConversion/func-backend-type-conversion.mlir @@ -5,8 +5,8 @@ // CHECK-LABEL: func @identity( // CHECK-SAME: %[[ARG:.*]]: tensor) -> tensor { -// CHECK: %[[TENSOR:.*]] = torch.from_builtin_tensor %[[ARG]] : tensor -> !torch.vtensor<[],f32> -// CHECK: %[[MEMREF:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[TENSOR:.*]] = torch_c.from_builtin_tensor %[[ARG]] : tensor -> !torch.vtensor<[],f32> +// CHECK: %[[MEMREF:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor // CHECK: return %[[MEMREF]] : tensor func @identity(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { return %arg0 : !torch.vtensor<[],f32> @@ -14,12 +14,12 @@ func @identity(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { // CHECK-LABEL: func @block_arguments( // CHECK-SAME: %[[ARG:.*]]: tensor) -> tensor { -// CHECK: %[[T1:.*]] = torch.from_builtin_tensor %[[ARG]] : tensor -> !torch.vtensor<[],f32> -// CHECK: %[[M1:.*]] = torch.to_builtin_tensor %[[T1]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[T1:.*]] = torch_c.from_builtin_tensor %[[ARG]] : tensor -> !torch.vtensor<[],f32> +// CHECK: %[[M1:.*]] = torch_c.to_builtin_tensor %[[T1]] : !torch.vtensor<[],f32> -> tensor // CHECK: br ^bb1(%[[M1]] : tensor) // CHECK: ^bb1(%[[BBARG:.*]]: tensor): -// CHECK: %[[T2:.*]] = torch.from_builtin_tensor %[[BBARG]] : tensor -> !torch.vtensor<[],f32> -// CHECK: %[[M2:.*]] = torch.to_builtin_tensor %[[T2]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[BBARG]] : tensor -> !torch.vtensor<[],f32> +// CHECK: %[[M2:.*]] = torch_c.to_builtin_tensor %[[T2]] : !torch.vtensor<[],f32> -> tensor // CHECK: return %[[M2]] : tensor func @block_arguments(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { br ^bb1(%arg0: !torch.vtensor<[],f32>) @@ -38,8 +38,8 @@ func @call_source() -> !torch.vtensor<[],f32> { } // CHECK-LABEL: func @call_sink( // CHECK-SAME: %[[ARG:.*]]: tensor) { -// CHECK: %[[TENSOR:.*]] = torch.from_builtin_tensor %[[ARG]] : tensor -> !torch.vtensor<[],f32> -// CHECK: %[[MEMREF:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[TENSOR:.*]] = torch_c.from_builtin_tensor %[[ARG]] : tensor -> !torch.vtensor<[],f32> +// CHECK: %[[MEMREF:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor // CHECK: call @sink(%[[MEMREF]]) : (tensor) -> () // CHECK: return func private @sink(!torch.vtensor<[],f32>) @@ -50,7 +50,7 @@ func @call_sink(%arg0: !torch.vtensor<[],f32>) { // CHECK-LABEL: func @unconverted_op_in_body() -> tensor { // CHECK: %[[TENSOR:.*]] = "test.source"() : () -> !torch.vtensor<[],f32> -// CHECK: %[[MEMREF:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[MEMREF:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor // CHECK: return %[[MEMREF]] : tensor func @unconverted_op_in_body() -> !torch.vtensor<[],f32> { %0 = "test.source"() : () -> !torch.vtensor<[],f32> @@ -98,8 +98,8 @@ func @bwhile(%arg0: i64, %arg1: i64) -> i64 { // CHECK-LABEL: func @identity$torch.bool( // CHECK-SAME: %[[ARG:.*]]: i1) -> i1 { -// CHECK: %[[TORCH_BOOL:.*]] = torch.from_i1 %[[ARG]] -// CHECK: %[[I1:.*]] = torch.to_i1 %[[TORCH_BOOL]] +// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[ARG]] +// CHECK: 
%[[I1:.*]] = torch_c.to_i1 %[[TORCH_BOOL]] // CHECK: return %[[I1]] : i1 func @identity$torch.bool(%arg0: !torch.bool) -> !torch.bool { return %arg0 : !torch.bool @@ -107,8 +107,8 @@ func @identity$torch.bool(%arg0: !torch.bool) -> !torch.bool { // CHECK-LABEL: func @identity$torch.int( // CHECK-SAME: %[[ARG:.*]]: i64) -> i64 { -// CHECK: %[[TORCH_INT:.*]] = torch.from_i64 %[[ARG]] -// CHECK: %[[I64:.*]] = torch.to_i64 %[[TORCH_INT]] +// CHECK: %[[TORCH_INT:.*]] = torch_c.from_i64 %[[ARG]] +// CHECK: %[[I64:.*]] = torch_c.to_i64 %[[TORCH_INT]] // CHECK: return %[[I64]] : i64 func @identity$torch.int(%arg0: !torch.int) -> !torch.int { return %arg0 : !torch.int @@ -116,8 +116,8 @@ func @identity$torch.int(%arg0: !torch.int) -> !torch.int { // CHECK-LABEL: func @identity$torch.float( // CHECK-SAME: %[[ARG:.*]]: f64) -> f64 { -// CHECK: %[[TORCH_FLOAT:.*]] = torch.from_f64 %[[ARG]] -// CHECK: %[[F64:.*]] = torch.to_f64 %[[TORCH_FLOAT]] +// CHECK: %[[TORCH_FLOAT:.*]] = torch_c.from_f64 %[[ARG]] +// CHECK: %[[F64:.*]] = torch_c.to_f64 %[[TORCH_FLOAT]] // CHECK: return %[[F64]] : f64 func @identity$torch.float(%arg0: !torch.float) -> !torch.float { return %arg0 : !torch.float diff --git a/test/Dialect/TorchConversion/ops.mlir b/test/Dialect/TorchConversion/ops.mlir new file mode 100644 index 000000000..ad0ff9459 --- /dev/null +++ b/test/Dialect/TorchConversion/ops.mlir @@ -0,0 +1,14 @@ +// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s + +// CHECK-LABEL: func @builtin_tensor_interop( +func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xsi8>, %arg2: !torch.vtensor<*,f32>, %arg3: !torch.vtensor<[3,?],si8>) { + // CHECK: torch_c.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32> + %0 = torch_c.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32> + // CHECK: torch_c.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8> + %1 = torch_c.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8> + // CHECK: torch_c.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32> + %2 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32> + // CHECK: torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xsi8> + %3 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xsi8> + return +} diff --git a/test/Dialect/Torch/verify-invariants-before-backend-lowering.mlir b/test/Dialect/TorchConversion/verify-invariants-before-backend-lowering.mlir similarity index 100% rename from test/Dialect/Torch/verify-invariants-before-backend-lowering.mlir rename to test/Dialect/TorchConversion/verify-invariants-before-backend-lowering.mlir