mirror of https://github.com/llvm/torch-mlir
Add TorchToIREE and factor out TorchConversion dialect.
This converts a basic list op (torch.prim.ListConstruct) to the IREE dialect. ``` def forward(self, x: float): return [x, x] ``` turns into: ``` builtin.func @forward(%arg0: !torch.float) -> !torch.list<!torch.float> { %0 = torch.prim.ListConstruct %arg0, %arg0 : (!torch.float, !torch.float) -> !torch.list<!torch.float> return %0 : !torch.list<!torch.float> } ``` which turns into: ``` builtin.func @forward(%arg0: f64) -> !iree.list<f64> { %c1 = constant 1 : index %c0 = constant 0 : index %c2 = constant 2 : index %0 = iree.list.create %c2 : !iree.list<f64> iree.list.set %0[%c0], %arg0 : !iree.list<f64>, f64 iree.list.set %0[%c1], %arg0 : !iree.list<f64>, f64 return %0 : !iree.list<f64> } ``` As part of doing this, I realized that it was time to formalize the IR form that we reach right before running TorchTo{Linalg,Std,...}. We now call it the "Torch backend contract". We then lower the "Torch backend contract" to the "npcomp backend contract", which involves the new TorchConversion (`torch_c`) dialect, which holds ops that need to operate on both the npcomp backend types (e.g. builtin tensors, i1, IREE list, etc.) and the `!torch` types. This made more sense, as I realized that if I didn't factor out `torch_c` then the Torch dialect would have a dependency on IREE dialect (we previously didn't notice this was an issue because we only depended on `builtin` types), which seemed wrong to me. Recommended review order: - TorchToIREE.cpp / `TorchToIREE/basic.mlir` - Look at the new structure of createTorchScriptToNpcompBackendPipeline. It now lives in TorchConversion/Transforms/Passes.cpp and cleanly calls into `Torch::createTorchScriptToTorchBackendPipeline` for the frontend lowering to the Torch backend contract. - Mechanical change extracting `torch_c.{to,from}_{i1,i64,f64,builtin_tensor,iree_list}` into a new TorchConversion dialect, and a few passes specific to the lowering from the Torch backend contract to the npcomp backend contract. - Minor fixes to TorchToLinalg.cpp to use unconverted operands (now that we convert lists as part of operand materialization, we need to use the original operands). Also added test for AtenMaxPool2dOp and fixed m_TorchConstantIntList. - TmpDeleteDeadIREELists pass. Temporary pass for deleting dead IREE lists that are created as part of operand materialization for conv/max pool/avg pool ops in TorchToLinalg.pull/283/head
parent
85ff8b692b
commit
cab8d922ec
|
@ -128,6 +128,10 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||||
"suffix = '${PYTHON_MODULE_SUFFIX}', "
|
"suffix = '${PYTHON_MODULE_SUFFIX}', "
|
||||||
"extension = '${PYTHON_MODULE_EXTENSION}")
|
"extension = '${PYTHON_MODULE_EXTENSION}")
|
||||||
|
|
||||||
|
# Include the iree-dialects external project.
|
||||||
|
set(LLVM_EXTERNAL_PROJECTS "iree-dialects")
|
||||||
|
set(LLVM_EXTERNAL_IREE_DIALECTS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects")
|
||||||
|
|
||||||
# LLVM configuration.
|
# LLVM configuration.
|
||||||
message(STATUS "*** ADDING LLVM ***")
|
message(STATUS "*** ADDING LLVM ***")
|
||||||
add_subdirectory(
|
add_subdirectory(
|
||||||
|
@ -177,6 +181,8 @@ include_directories(${LLVM_INCLUDE_DIRS})
|
||||||
include_directories(${MLIR_INCLUDE_DIRS})
|
include_directories(${MLIR_INCLUDE_DIRS})
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
|
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
|
||||||
|
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects/include)
|
||||||
|
include_directories(${CMAKE_CURRENT_BINARY_DIR}/llvm/tools/iree-dialects/include)
|
||||||
link_directories(${LLVM_BUILD_LIBRARY_DIR})
|
link_directories(${LLVM_BUILD_LIBRARY_DIR})
|
||||||
add_definitions(${LLVM_DEFINITIONS})
|
add_definitions(${LLVM_DEFINITIONS})
|
||||||
set(NPCOMP_TABLEGEN_ARGS "")
|
set(NPCOMP_TABLEGEN_ARGS "")
|
||||||
|
|
|
@ -22,7 +22,7 @@ using namespace mlir::iree;
|
||||||
void IREEDialect::initialize() {
|
void IREEDialect::initialize() {
|
||||||
addTypes<
|
addTypes<
|
||||||
#define GET_TYPEDEF_LIST
|
#define GET_TYPEDEF_LIST
|
||||||
#include "iree-dialects/Dialect/IREE/IREEOps.cpp.inc"
|
#include "iree-dialects/Dialect/IREE/IREEOpsTypes.cpp.inc"
|
||||||
>();
|
>();
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
|
|
|
@ -104,6 +104,14 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
|
||||||
let constructor = "mlir::NPCOMP::createConvertTorchToLinalgPass()";
|
let constructor = "mlir::NPCOMP::createConvertTorchToLinalgPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def ConvertTorchToIREE : Pass<"convert-torch-to-iree", "FuncOp"> {
|
||||||
|
let summary = "Convert recognized Torch ops to IREE ops";
|
||||||
|
let description = [{
|
||||||
|
TODO
|
||||||
|
}];
|
||||||
|
let constructor = "mlir::NPCOMP::createConvertTorchToIREEPass()";
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Basicpy conversions
|
// Basicpy conversions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_CONVERSION_TORCHTOIREE_TORCHTOIREE_H
|
||||||
|
#define NPCOMP_CONVERSION_TORCHTOIREE_TORCHTOIREE_H
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace NPCOMP {
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToIREEPass();
|
||||||
|
}
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_CONVERSION_TORCHTOIREE_TORCHTOIREE_H
|
|
@ -3,3 +3,4 @@ add_subdirectory(Numpy)
|
||||||
add_subdirectory(Refback)
|
add_subdirectory(Refback)
|
||||||
add_subdirectory(Refbackrt)
|
add_subdirectory(Refbackrt)
|
||||||
add_subdirectory(Torch)
|
add_subdirectory(Torch)
|
||||||
|
add_subdirectory(TorchConversion)
|
||||||
|
|
|
@ -19,7 +19,18 @@ def Torch_Dialect : Dialect {
|
||||||
|
|
||||||
This dialect maintains a fairly isomorphic representation with TorchScript.
|
This dialect maintains a fairly isomorphic representation with TorchScript.
|
||||||
|
|
||||||
TODO: Add more detail here.
|
This dialect also provides transforms that lower it to the
|
||||||
|
"Torch backend contract", which is an IR form that we present to
|
||||||
|
later conversions, such as conversion to the npcomp backend contract.
|
||||||
|
The Torch backend contract significantly simplifies the IR representation
|
||||||
|
and puts it in a form easier for later lowering to work on. Specifically:
|
||||||
|
- The TorchScript object graph has been flattened to a list of globals (see
|
||||||
|
the GlobalizeObjectGraph tranformation).
|
||||||
|
- Most of the operations have been changed to operate on value-semantic
|
||||||
|
tensors (see MaximizeValueSemantics)
|
||||||
|
- The number of op variants have been reduced (see ReduceOpVariants)
|
||||||
|
- Tensor sizes have been analyzed and static ranks inferred where possible
|
||||||
|
and propagated throughout the program.
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let hasRegionArgAttrVerify = 1;
|
let hasRegionArgAttrVerify = 1;
|
||||||
|
|
|
@ -87,15 +87,16 @@ struct torch_list_construct_op_binder {
|
||||||
: bind_values(bvs) {}
|
: bind_values(bvs) {}
|
||||||
|
|
||||||
bool match(Operation *op) {
|
bool match(Operation *op) {
|
||||||
if (auto constantNums = dyn_cast<Torch::PrimListConstructOp>(op)) {
|
auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
|
||||||
for (Value value : constantNums.elements()) {
|
if (!listConstruct)
|
||||||
|
return false;
|
||||||
|
for (Value value : listConstruct.elements()) {
|
||||||
int64_t num;
|
int64_t num;
|
||||||
if (matchPattern(value, m_TorchConstantInt(&num)))
|
if (matchPattern(value, m_TorchConstantInt(&num)))
|
||||||
bind_values.push_back(num);
|
bind_values.push_back(num);
|
||||||
else
|
else
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -611,156 +611,6 @@ def Torch_ConstantBoolOp : Torch_Op<"constant.bool",
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Conversions to builtin types.
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def Torch_ToBuiltinTensorOp : Torch_Op<"to_builtin_tensor", [
|
|
||||||
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
|
||||||
]> {
|
|
||||||
let summary = "Convert a `!torch.vtensor` to a `tensor`";
|
|
||||||
let description = [{
|
|
||||||
This op only operates on ValueTensorType, to avoid conflating conversions
|
|
||||||
between value-semantic and non-value-semantic types.
|
|
||||||
}];
|
|
||||||
let arguments = (ins
|
|
||||||
Torch_ValueTensorType:$operand
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
AnyTensor:$result
|
|
||||||
);
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$operand attr-dict `:` type($operand) `->` type($result)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_FromBuiltinTensorOp : Torch_Op<"from_builtin_tensor", [
|
|
||||||
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
|
||||||
]> {
|
|
||||||
let summary = "Convert a `tensor` to a `!torch.vtensor`";
|
|
||||||
let description = [{
|
|
||||||
This op only operates on ValueTensorType, to avoid conflating conversions
|
|
||||||
between value-semantic and non-value-semantic types.
|
|
||||||
}];
|
|
||||||
let arguments = (ins
|
|
||||||
AnyTensor:$operand
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
Torch_ValueTensorType:$result
|
|
||||||
);
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$operand attr-dict `:` type($operand) `->` type($result)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_ToI1Op : Torch_Op<"to_i1", [
|
|
||||||
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
|
||||||
]> {
|
|
||||||
let summary = "Convert a `!torch.bool` to an `i1`";
|
|
||||||
let description = [{
|
|
||||||
This op is primarily useful as a materialization during dialect conversion.
|
|
||||||
}];
|
|
||||||
let arguments = (ins
|
|
||||||
Torch_BoolType:$operand
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
I1:$result
|
|
||||||
);
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$operand attr-dict
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_FromI1Op : Torch_Op<"from_i1", [
|
|
||||||
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
|
||||||
]> {
|
|
||||||
let summary = "Convert an `i1` to a `!torch.bool`";
|
|
||||||
let description = [{
|
|
||||||
This op is primarily useful as a materialization during dialect conversion.
|
|
||||||
}];
|
|
||||||
let arguments = (ins
|
|
||||||
I1:$operand
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
Torch_BoolType:$result
|
|
||||||
);
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$operand attr-dict
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_ToI64Op : Torch_Op<"to_i64", [
|
|
||||||
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
|
||||||
]> {
|
|
||||||
let summary = "Convert a `!torch.int` to an `i64`";
|
|
||||||
let description = [{
|
|
||||||
This op is primarily useful as a materialization during dialect conversion.
|
|
||||||
}];
|
|
||||||
let arguments = (ins
|
|
||||||
Torch_IntType:$operand
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
I64:$result
|
|
||||||
);
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$operand attr-dict
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_FromI64Op : Torch_Op<"from_i64", [
|
|
||||||
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
|
||||||
]> {
|
|
||||||
let summary = "Convert an `i64` to a `!torch.int`";
|
|
||||||
let description = [{
|
|
||||||
This op is primarily useful as a materialization during dialect conversion.
|
|
||||||
}];
|
|
||||||
let arguments = (ins
|
|
||||||
I64:$operand
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
Torch_IntType:$result
|
|
||||||
);
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$operand attr-dict
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_ToF64Op : Torch_Op<"to_f64", [
|
|
||||||
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
|
||||||
]> {
|
|
||||||
let summary = "Convert a `!torch.float` to an `f64`";
|
|
||||||
let description = [{
|
|
||||||
This op is primarily useful as a materialization during dialect conversion.
|
|
||||||
}];
|
|
||||||
let arguments = (ins
|
|
||||||
Torch_FloatType:$operand
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
F64:$result
|
|
||||||
);
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$operand attr-dict
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_FromF64Op : Torch_Op<"from_f64", [
|
|
||||||
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
|
||||||
]> {
|
|
||||||
let summary = "Convert an `f64` to a `!torch.float`";
|
|
||||||
let description = [{
|
|
||||||
This op is primarily useful as a materialization during dialect conversion.
|
|
||||||
}];
|
|
||||||
let arguments = (ins
|
|
||||||
F64:$operand
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
Torch_FloatType:$result
|
|
||||||
);
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$operand attr-dict
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Additional ops used to model TorchScript's Graph's / Node's.
|
// Additional ops used to model TorchScript's Graph's / Node's.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -31,16 +31,14 @@ struct TorchLoweringPipelineOptions
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Creates a pipeline that lowers the object graph IR that is produced by
|
/// Creates a pipeline that lowers the object graph IR that is produced by
|
||||||
/// TorchScript import into the form expected by npcomp-verify-backend-contract.
|
/// TorchScript import into the form expected by torch-verify-backend-contract.
|
||||||
void createLowerObjectGraphPipeline(
|
void createTorchScriptToTorchBackendPipeline(
|
||||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
||||||
|
|
||||||
/// Creates a pipeline that lowers a flat list of funcs and global slots
|
/// Creates a pipeline that lowers a flat list of funcs and global slots
|
||||||
/// with the torch and aten dialects and mutable arrays and converts it to
|
/// with the torch and aten dialects and mutable arrays and converts it to
|
||||||
/// the form required by npcomp-verify-backend-contract, in particular
|
/// the form required by torch-verify-backend-contract.
|
||||||
/// lowering most arrays to ranked tensors of known dtype, lowering aten ops to
|
void createGlobalizedModuleToTorchBackendPipeline(
|
||||||
/// linalg, converting torch.prim.* ops to elementary math operations.
|
|
||||||
void createLowerToNpcompBackendPipeline(
|
|
||||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
|
||||||
|
@ -55,14 +53,6 @@ std::unique_ptr<OperationPass<FuncOp>> createMaximizeValueSemanticsPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
|
||||||
createVerifyInvariantsBeforeBackendLoweringPass();
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
|
||||||
createFinalizingBackendTypeConversionPass();
|
|
||||||
|
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
|
|
||||||
/// Registers all Torch transformation passes.
|
/// Registers all Torch transformation passes.
|
||||||
|
|
|
@ -214,40 +214,4 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def VerifyInvariantsBeforeBackendLowering
|
|
||||||
: Pass<"torch-verify-invariants-before-backend-lowering", "ModuleOp"> {
|
|
||||||
let summary = "Verify invariants required by backend lowering";
|
|
||||||
let constructor =
|
|
||||||
"mlir::NPCOMP::Torch::createVerifyInvariantsBeforeBackendLoweringPass()";
|
|
||||||
let description = [{
|
|
||||||
This pass checks any invariants needed by the process of lowering the
|
|
||||||
`torch` dialect to the npcomp backend contract.
|
|
||||||
|
|
||||||
The most important invariant is that all tensors should be ranked and have
|
|
||||||
a known dtype. It is useful to catch this early because it usually
|
|
||||||
represents a simple bug in RefineTypes, but can manifest as many different
|
|
||||||
kinds of obscure symptoms during lowering.
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "ModuleOp"> {
|
|
||||||
let summary = "Convert functions to operate on builtin tensors";
|
|
||||||
let constructor = "mlir::NPCOMP::Torch::createFuncBackendTypeConversionPass()";
|
|
||||||
let description = [{
|
|
||||||
Partial type conversion pass analogous in scope to the upstream
|
|
||||||
`func-bufferize` pass. See details there.
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def FinalizingBackendTypeConversion
|
|
||||||
: Pass<"torch-finalizing-backend-type-conversion", "FuncOp"> {
|
|
||||||
let summary = "Finalizes a partial conversion to builtin tensors";
|
|
||||||
let constructor =
|
|
||||||
"mlir::NPCOMP::Torch::createFinalizingBackendTypeConversionPass()";
|
|
||||||
let description = [{
|
|
||||||
Analogous in scope to the upstream `finalizing-bufferize` pass.
|
|
||||||
See details there.
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // NPCOMP_TORCH_PASSES
|
#endif // NPCOMP_TORCH_PASSES
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
add_subdirectory(IR)
|
||||||
|
add_subdirectory(Transforms)
|
|
@ -0,0 +1,15 @@
|
||||||
|
set(LLVM_TARGET_DEFINITIONS TorchConversionOps.td)
|
||||||
|
mlir_tablegen(TorchConversionOps.h.inc -gen-op-decls)
|
||||||
|
mlir_tablegen(TorchConversionOps.cpp.inc -gen-op-defs)
|
||||||
|
mlir_tablegen(TorchConversionDialect.h.inc -gen-dialect-decls -dialect=torch_c)
|
||||||
|
mlir_tablegen(TorchConversionDialect.cpp.inc -gen-dialect-defs -dialect=torch_c)
|
||||||
|
add_public_tablegen_target(MLIRTorchConversionOpsIncGen)
|
||||||
|
add_dependencies(mlir-headers MLIRTorchConversionOpsIncGen)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS TorchConversionTypes.td)
|
||||||
|
mlir_tablegen(TorchConversionTypes.h.inc -gen-typedef-decls)
|
||||||
|
mlir_tablegen(TorchConversionTypes.cpp.inc -gen-typedef-defs)
|
||||||
|
add_public_tablegen_target(MLIRTorchConversionTypesIncGen)
|
||||||
|
|
||||||
|
add_mlir_doc(TorchConversionDialect TorchConversionDialect TorchConversion/ -gen-dialect-doc)
|
||||||
|
add_mlir_doc(TorchConversionOps TorchConversionOps TorchConversion/ -gen-op-doc)
|
|
@ -0,0 +1,29 @@
|
||||||
|
//===-------------------------------------------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCHCONVERSION_BASE
|
||||||
|
#define TORCHCONVERSION_BASE
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
def TorchConversion_Dialect : Dialect {
|
||||||
|
// `torch_conversion` is too verbose.
|
||||||
|
let name = "torch_c";
|
||||||
|
let cppNamespace = "::mlir::NPCOMP::TorchConversion";
|
||||||
|
let description = [{
|
||||||
|
This dialect contains ops and transforms for converting from the Torch
|
||||||
|
backend contract to the npcomp backend contract.
|
||||||
|
|
||||||
|
This mainly consists of converting ops and types from `torch` dialect
|
||||||
|
to the mix of dialects of the npcomp backend contract, such as tensor
|
||||||
|
ops being converted linalg-on-tensors, lists being converted to IREE lists,
|
||||||
|
and !torch.float being converted to `f64`.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // TORCHCONVERSION_BASE
|
|
@ -0,0 +1,16 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHDIALECT_H
|
||||||
|
#define NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHDIALECT_H
|
||||||
|
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
|
||||||
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h.inc"
|
||||||
|
|
||||||
|
#endif // NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHDIALECT_H
|
|
@ -0,0 +1,25 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHOPS_H
|
||||||
|
#define NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHOPS_H
|
||||||
|
|
||||||
|
#include "iree-dialects/Dialect/IREE/IREEDialect.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "mlir/Interfaces/CastInterfaces.h"
|
||||||
|
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||||
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||||
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h.inc"
|
||||||
|
|
||||||
|
#endif // NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHOPS_H
|
|
@ -0,0 +1,207 @@
|
||||||
|
//===-------------------------------------------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCHCONVERSION_OPS
|
||||||
|
#define TORCHCONVERSION_OPS
|
||||||
|
|
||||||
|
include "mlir/IR/OpAsmInterface.td"
|
||||||
|
include "mlir/IR/SymbolInterfaces.td"
|
||||||
|
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||||
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
include "npcomp/Dialect/TorchConversion/IR/TorchConversionBase.td"
|
||||||
|
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
|
||||||
|
include "iree-dialects/Dialect/IREE/IREEDialect.td"
|
||||||
|
|
||||||
|
class TorchConversion_Op<string mnemonic, list<OpTrait> traits = []>
|
||||||
|
: Op<TorchConversion_Dialect, mnemonic, traits> {
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Conversions to backend types.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor", [
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
||||||
|
]> {
|
||||||
|
let summary = "Convert a `!torch.vtensor` to a `tensor`";
|
||||||
|
let description = [{
|
||||||
|
This op only operates on ValueTensorType, to avoid conflating conversions
|
||||||
|
between value-semantic and non-value-semantic types.
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_ValueTensorType:$operand
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTensor:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$operand attr-dict `:` type($operand) `->` type($result)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tensor", [
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
||||||
|
]> {
|
||||||
|
let summary = "Convert a `tensor` to a `!torch.vtensor`";
|
||||||
|
let description = [{
|
||||||
|
This op only operates on ValueTensorType, to avoid conflating conversions
|
||||||
|
between value-semantic and non-value-semantic types.
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTensor:$operand
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_ValueTensorType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$operand attr-dict `:` type($operand) `->` type($result)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def TorchConversion_ToI1Op : TorchConversion_Op<"to_i1", [
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
||||||
|
]> {
|
||||||
|
let summary = "Convert a `!torch.bool` to an `i1`";
|
||||||
|
let description = [{
|
||||||
|
This op is primarily useful as a materialization during dialect conversion.
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_BoolType:$operand
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
I1:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$operand attr-dict
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def TorchConversion_FromI1Op : TorchConversion_Op<"from_i1", [
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
||||||
|
]> {
|
||||||
|
let summary = "Convert an `i1` to a `!torch.bool`";
|
||||||
|
let description = [{
|
||||||
|
This op is primarily useful as a materialization during dialect conversion.
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
I1:$operand
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_BoolType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$operand attr-dict
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def TorchConversion_ToI64Op : TorchConversion_Op<"to_i64", [
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
||||||
|
]> {
|
||||||
|
let summary = "Convert a `!torch.int` to an `i64`";
|
||||||
|
let description = [{
|
||||||
|
This op is primarily useful as a materialization during dialect conversion.
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_IntType:$operand
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
I64:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$operand attr-dict
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def TorchConversion_FromI64Op : TorchConversion_Op<"from_i64", [
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
||||||
|
]> {
|
||||||
|
let summary = "Convert an `i64` to a `!torch.int`";
|
||||||
|
let description = [{
|
||||||
|
This op is primarily useful as a materialization during dialect conversion.
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
I64:$operand
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_IntType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$operand attr-dict
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def TorchConversion_ToF64Op : TorchConversion_Op<"to_f64", [
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
||||||
|
]> {
|
||||||
|
let summary = "Convert a `!torch.float` to an `f64`";
|
||||||
|
let description = [{
|
||||||
|
This op is primarily useful as a materialization during dialect conversion.
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_FloatType:$operand
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
F64:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$operand attr-dict
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def TorchConversion_FromF64Op : TorchConversion_Op<"from_f64", [
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
||||||
|
]> {
|
||||||
|
let summary = "Convert an `f64` to a `!torch.float`";
|
||||||
|
let description = [{
|
||||||
|
This op is primarily useful as a materialization during dialect conversion.
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
F64:$operand
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_FloatType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$operand attr-dict
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Verify the element types match.
|
||||||
|
def TorchConversion_ToIREEListOp : TorchConversion_Op<"to_iree_list", [
|
||||||
|
]> {
|
||||||
|
let summary = "Convert a `!torch.list` to a `!iree.list`";
|
||||||
|
let description = [{
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_ListType:$operand
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
IREE_ListType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$operand attr-dict `:` type($operand) `->` type($result)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def TorchConversion_FromIREEListOp : TorchConversion_Op<"from_iree_list", [
|
||||||
|
]> {
|
||||||
|
let summary = "Convert a `!iree.list` to a `!torch.list`";
|
||||||
|
let description = [{
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
IREE_ListType:$operand
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_ListType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$operand attr-dict `:` type($operand) `->` type($result)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // TORCHCONVERSION_OPS
|
|
@ -6,21 +6,26 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_BACKENDTYPECONVERSION_H
|
#ifndef NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_BACKENDTYPECONVERSION_H
|
||||||
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_BACKENDTYPECONVERSION_H
|
#define NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_BACKENDTYPECONVERSION_H
|
||||||
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
namespace Torch {
|
namespace TorchConversion {
|
||||||
|
|
||||||
|
/// Get the dependent dialects which might be involved in a backend type
|
||||||
|
/// conversion.
|
||||||
|
void getBackendTypeConversionDependentDialects(DialectRegistry ®istry);
|
||||||
|
|
||||||
/// Set up the provided ConversionTarget and TypeConverter for converting
|
/// Set up the provided ConversionTarget and TypeConverter for converting
|
||||||
/// from `torch` dialect types to the types along the npcomp backend boundary
|
/// from `torch` dialect types to the types along the npcomp backend boundary
|
||||||
/// (which currently consist only of builtin types).
|
/// (which currently consist only of builtin types).
|
||||||
void setupBackendTypeConversion(ConversionTarget &target,
|
void setupBackendTypeConversion(ConversionTarget &target,
|
||||||
TypeConverter &typeConverter);
|
TypeConverter &typeConverter);
|
||||||
} // namespace Torch
|
} // namespace TorchConversion
|
||||||
} // namespace NPCOMP
|
} // namespace NPCOMP
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_BACKENDTYPECONVERSION_H
|
#endif // NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_BACKENDTYPECONVERSION_H
|
|
@ -0,0 +1,5 @@
|
||||||
|
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||||
|
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||||
|
add_public_tablegen_target(NPCOMPTorchConversionPassIncGen)
|
||||||
|
|
||||||
|
add_mlir_doc(Passes NPCOMPTorchConversionTransforms ./ -gen-pass-doc)
|
|
@ -0,0 +1,44 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
|
||||||
|
#define NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace NPCOMP {
|
||||||
|
namespace TorchConversion {
|
||||||
|
|
||||||
|
/// Creates a pipeline that lowers the object graph IR that is produced by
|
||||||
|
/// TorchScript import into the form expected by npcomp-verify-backend-contract.
|
||||||
|
void createTorchScriptToNpcompBackendPipeline(
|
||||||
|
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options);
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
createVerifyInvariantsBeforeBackendLoweringPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
|
createFinalizingBackendTypeConversionPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createTmpDeleteDeadIREEListsPass();
|
||||||
|
|
||||||
|
} // namespace TorchConversion
|
||||||
|
|
||||||
|
/// Registers all Torch transformation passes.
|
||||||
|
void registerTorchConversionPasses();
|
||||||
|
|
||||||
|
} // namespace NPCOMP
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
|
|
@ -0,0 +1,76 @@
|
||||||
|
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_TORCHCONVERSION_PASSES
|
||||||
|
#define NPCOMP_TORCHCONVERSION_PASSES
|
||||||
|
|
||||||
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
|
def VerifyInvariantsBeforeBackendLowering
|
||||||
|
: Pass<"torch-verify-invariants-before-backend-lowering", "ModuleOp"> {
|
||||||
|
let summary = "Verify invariants required by backend lowering";
|
||||||
|
let constructor =
|
||||||
|
"mlir::NPCOMP::TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()";
|
||||||
|
let description = [{
|
||||||
|
This pass checks any invariants needed by the process of lowering the
|
||||||
|
`torch` dialect to the npcomp backend contract.
|
||||||
|
|
||||||
|
The most important invariant is that all tensors should be ranked and have
|
||||||
|
a known dtype. It is useful to catch this early because it usually
|
||||||
|
represents a simple bug in RefineTypes, but can manifest as many different
|
||||||
|
kinds of obscure symptoms during lowering.
|
||||||
|
|
||||||
|
TODO: This pass should probably be phrased as checking the
|
||||||
|
"torch backend contract" and moved to that dialect once we have more
|
||||||
|
substantial definition definition around what that layer is from an
|
||||||
|
"allowlist" perspective.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "ModuleOp"> {
|
||||||
|
let summary = "Convert functions to operate on builtin tensors";
|
||||||
|
let constructor = "mlir::NPCOMP::TorchConversion::createFuncBackendTypeConversionPass()";
|
||||||
|
let description = [{
|
||||||
|
Partial type conversion pass analogous in scope to the upstream
|
||||||
|
`func-bufferize` pass. See details there.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def FinalizingBackendTypeConversion
|
||||||
|
: Pass<"torch-finalizing-backend-type-conversion", "FuncOp"> {
|
||||||
|
let summary = "Finalizes a partial conversion to builtin tensors";
|
||||||
|
let constructor =
|
||||||
|
"mlir::NPCOMP::TorchConversion::createFinalizingBackendTypeConversionPass()";
|
||||||
|
let description = [{
|
||||||
|
Analogous in scope to the upstream `finalizing-bufferize` pass.
|
||||||
|
See details there.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def TmpDeleteDeadIREELists
|
||||||
|
: Pass<"torch-tmp-delete-dead-lists", "FuncOp"> {
|
||||||
|
let summary = "Delete dead !iree.list ops";
|
||||||
|
let constructor =
|
||||||
|
"mlir::NPCOMP::TorchConversion::createTmpDeleteDeadIREEListsPass()";
|
||||||
|
let description = [{
|
||||||
|
Runs a few patterns to delete dead !iree.list ops until IREE can support
|
||||||
|
running them. Currently, these will get materialized as part of conversions
|
||||||
|
for ops like AtenConv2dOp that have list operands, even though they are dead
|
||||||
|
(for those ops, we pattern match a specific case of static constant lists).
|
||||||
|
Currently, this will break execution of those tests because the IREE
|
||||||
|
side of these ops still doesn't work (nor is IREE able to delete them
|
||||||
|
itself).
|
||||||
|
|
||||||
|
TODO: Add support to IREE to run these ops E2E.
|
||||||
|
TODO: Remove this pass once IREE can run them e2e.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#endif // NPCOMP_TORCHCONVERSION_PASSES
|
|
@ -17,6 +17,7 @@ add_npcomp_library(NPCOMPCommonBackend
|
||||||
MLIRTensor
|
MLIRTensor
|
||||||
MLIRStandard
|
MLIRStandard
|
||||||
MLIRMath
|
MLIRMath
|
||||||
|
IREEDialectsIREEDialect
|
||||||
)
|
)
|
||||||
|
|
||||||
mlir_check_all_link_libraries(NPCOMPCommonBackend)
|
mlir_check_all_link_libraries(NPCOMPCommonBackend)
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
|
#include "iree-dialects/Dialect/IREE/IREEOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
#include "mlir/Dialect/Math/IR/Math.h"
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
@ -32,6 +33,7 @@ class VerifyBackendContractPass
|
||||||
return type;
|
return type;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
});
|
});
|
||||||
|
converter.addConversion([](iree::ListType type) { return type; });
|
||||||
TypeConverter scalarConverter;
|
TypeConverter scalarConverter;
|
||||||
for (TypeConverter *c : {&converter, &scalarConverter}) {
|
for (TypeConverter *c : {&converter, &scalarConverter}) {
|
||||||
c->addConversion([](FloatType type) { return type; });
|
c->addConversion([](FloatType type) { return type; });
|
||||||
|
@ -59,8 +61,7 @@ class VerifyBackendContractPass
|
||||||
// Tensor operations should go through linalg and the tensor dialect.
|
// Tensor operations should go through linalg and the tensor dialect.
|
||||||
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes);
|
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes);
|
||||||
target.addDynamicallyLegalDialect<tensor::TensorDialect>(opHasLegalTypes);
|
target.addDynamicallyLegalDialect<tensor::TensorDialect>(opHasLegalTypes);
|
||||||
// DimOp is used to query tensor sizes.
|
target.addDynamicallyLegalDialect<iree::IREEDialect>(opHasLegalTypes);
|
||||||
target.addDynamicallyLegalOp<tensor::DimOp>(opHasLegalTypes);
|
|
||||||
|
|
||||||
// AssertOp is used to terminate the program for error guards.
|
// AssertOp is used to terminate the program for error guards.
|
||||||
target.addLegalOp<AssertOp>();
|
target.addLegalOp<AssertOp>();
|
||||||
|
|
|
@ -32,6 +32,7 @@ add_npcomp_library(NPCOMPInitAll
|
||||||
NPCOMPRefBackend
|
NPCOMPRefBackend
|
||||||
NPCOMPRefbackDialect
|
NPCOMPRefbackDialect
|
||||||
NPCOMPTorchDialect
|
NPCOMPTorchDialect
|
||||||
|
NPCOMPTorchConversionDialect
|
||||||
NPCOMPRefbackrtDialect
|
NPCOMPRefbackrtDialect
|
||||||
NPCOMPBasicpyDialect
|
NPCOMPBasicpyDialect
|
||||||
NPCOMPBasicpyPasses
|
NPCOMPBasicpyPasses
|
||||||
|
@ -39,6 +40,7 @@ add_npcomp_library(NPCOMPInitAll
|
||||||
NPCOMPNumpyDialect
|
NPCOMPNumpyDialect
|
||||||
NPCOMPNumpyPasses
|
NPCOMPNumpyPasses
|
||||||
NPCOMPTypingPasses
|
NPCOMPTypingPasses
|
||||||
|
IREEDialectsIREEDialect
|
||||||
|
|
||||||
# TODO: We shouldn't need npcomp_conversion_libs here, but we have
|
# TODO: We shouldn't need npcomp_conversion_libs here, but we have
|
||||||
# some dialect transform libraries accumulating into that property.
|
# some dialect transform libraries accumulating into that property.
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
add_subdirectory(TorchToIREE)
|
||||||
add_subdirectory(TorchToLinalg)
|
add_subdirectory(TorchToLinalg)
|
||||||
add_subdirectory(TorchToSCF)
|
add_subdirectory(TorchToSCF)
|
||||||
add_subdirectory(TorchToStd)
|
add_subdirectory(TorchToStd)
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
#include "npcomp/Conversion/Passes.h"
|
#include "npcomp/Conversion/Passes.h"
|
||||||
|
|
||||||
#include "npcomp/Conversion/BasicpyToStd/Passes.h"
|
#include "npcomp/Conversion/BasicpyToStd/Passes.h"
|
||||||
|
#include "npcomp/Conversion/TorchToIREE/TorchToIREE.h"
|
||||||
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||||
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
|
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
|
||||||
#include "npcomp/Conversion/TorchToStd/TorchToStd.h"
|
#include "npcomp/Conversion/TorchToStd/TorchToStd.h"
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
add_npcomp_conversion_library(NPCOMPTorchToIREE
|
||||||
|
TorchToIREE.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TorchToIREE
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
NPCOMPConversionPassIncGen
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRPass
|
||||||
|
NPCOMPTorchDialect
|
||||||
|
MLIRStandard
|
||||||
|
IREEDialectsIREEDialect
|
||||||
|
)
|
|
@ -0,0 +1,89 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "npcomp/Conversion/TorchToIREE/TorchToIREE.h"
|
||||||
|
|
||||||
|
#include "../PassDetail.h"
|
||||||
|
#include "iree-dialects/Dialect/IREE/IREEOps.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Dialect/Traits.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||||
|
#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
using namespace mlir::NPCOMP::Torch;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// The patterns
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertPrimListConstructOp
|
||||||
|
: public OpConversionPattern<PrimListConstructOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(PrimListConstructOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto type = getTypeConverter()->convertType(op.getType());
|
||||||
|
auto capacity =
|
||||||
|
rewriter.create<ConstantIndexOp>(op.getLoc(), op->getNumOperands());
|
||||||
|
auto ireeList =
|
||||||
|
rewriter.replaceOpWithNewOp<iree::ListCreateOp>(op, type, capacity);
|
||||||
|
for (int i = 0, e = operands.size(); i != e; ++i) {
|
||||||
|
auto index = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
|
||||||
|
rewriter.create<iree::ListSetOp>(op.getLoc(), ireeList, index,
|
||||||
|
operands[i]);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// The pass
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertTorchToIREE : public ConvertTorchToIREEBase<ConvertTorchToIREE> {
|
||||||
|
public:
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<StandardOpsDialect>();
|
||||||
|
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
ConversionTarget target(*context);
|
||||||
|
target.addLegalDialect<iree::IREEDialect>();
|
||||||
|
target.addLegalDialect<StandardOpsDialect>();
|
||||||
|
|
||||||
|
TypeConverter typeConverter;
|
||||||
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
|
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||||
|
|
||||||
|
RewritePatternSet patterns(context);
|
||||||
|
|
||||||
|
patterns.add<ConvertPrimListConstructOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<PrimListConstructOp>();
|
||||||
|
|
||||||
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
|
std::move(patterns))))
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
|
mlir::NPCOMP::createConvertTorchToIREEPass() {
|
||||||
|
return std::make_unique<ConvertTorchToIREE>();
|
||||||
|
}
|
|
@ -16,7 +16,8 @@
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h"
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||||
|
#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
|
@ -63,7 +64,7 @@ static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
|
||||||
// to end. Constant values can be be extracted directly and non constant
|
// to end. Constant values can be be extracted directly and non constant
|
||||||
// list values are not supported.
|
// list values are not supported.
|
||||||
// TODO: loose this constraint when properly support list type
|
// TODO: loose this constraint when properly support list type
|
||||||
static bool isConstantIntListMatching(Value &value,
|
static bool isConstantIntListMatching(Value value,
|
||||||
llvm::SmallVectorImpl<int64_t> &expects) {
|
llvm::SmallVectorImpl<int64_t> &expects) {
|
||||||
llvm::SmallVector<int64_t> intValues;
|
llvm::SmallVector<int64_t> intValues;
|
||||||
if (!matchPattern(value, m_TorchConstantIntList(intValues)))
|
if (!matchPattern(value, m_TorchConstantIntList(intValues)))
|
||||||
|
@ -171,7 +172,6 @@ public:
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
MLIRContext *context = op->getContext();
|
MLIRContext *context = op->getContext();
|
||||||
AtenAdaptiveAvgPool2dOp::Adaptor adaptor(operands);
|
AtenAdaptiveAvgPool2dOp::Adaptor adaptor(operands);
|
||||||
Value outputSize = adaptor.output_size();
|
|
||||||
Value input = adaptor.self(); /* in form of N*C*H*W */
|
Value input = adaptor.self(); /* in form of N*C*H*W */
|
||||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||||
Type elementType = inputType.getElementType();
|
Type elementType = inputType.getElementType();
|
||||||
|
@ -183,7 +183,10 @@ public:
|
||||||
return rewriter.notifyMatchFailure(op, "input should be rank 4");
|
return rewriter.notifyMatchFailure(op, "input should be rank 4");
|
||||||
|
|
||||||
SmallVector<int64_t, 2> expects{1, 1};
|
SmallVector<int64_t, 2> expects{1, 1};
|
||||||
if (!isConstantIntListMatching(outputSize, expects))
|
// Pattern match against the op's original operands, because otherwise we
|
||||||
|
// will get the lowered version of the operands which is harder to pattern
|
||||||
|
// match.
|
||||||
|
if (!isConstantIntListMatching(op.output_size(), expects))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only support output_size with H and W both equal to constant 1");
|
op, "only support output_size with H and W both equal to constant 1");
|
||||||
|
|
||||||
|
@ -269,9 +272,6 @@ public:
|
||||||
AtenConv2dOp::Adaptor adaptor(operands);
|
AtenConv2dOp::Adaptor adaptor(operands);
|
||||||
Value input = adaptor.input(); /* in form of N*C*H*W */
|
Value input = adaptor.input(); /* in form of N*C*H*W */
|
||||||
Value weight = adaptor.weight(); /* in form of F*C*H*W */
|
Value weight = adaptor.weight(); /* in form of F*C*H*W */
|
||||||
Value padding = adaptor.padding();
|
|
||||||
Value stride = adaptor.stride();
|
|
||||||
Value dilation = adaptor.dilation();
|
|
||||||
Value groups = adaptor.groups();
|
Value groups = adaptor.groups();
|
||||||
|
|
||||||
Type elementType =
|
Type elementType =
|
||||||
|
@ -291,18 +291,21 @@ public:
|
||||||
Value weightH = getDimOp(rewriter, loc, weight, 2);
|
Value weightH = getDimOp(rewriter, loc, weight, 2);
|
||||||
Value weightW = getDimOp(rewriter, loc, weight, 3);
|
Value weightW = getDimOp(rewriter, loc, weight, 3);
|
||||||
|
|
||||||
|
// Pattern match against the op's original operands, because otherwise we
|
||||||
|
// will get the lowered version of the operands which is harder to pattern
|
||||||
|
// match.
|
||||||
llvm::SmallVector<int64_t> paddingInts;
|
llvm::SmallVector<int64_t> paddingInts;
|
||||||
if (!matchPattern(padding, m_TorchConstantIntList(paddingInts))) {
|
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts))) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only support constant padding values");
|
op, "only support constant padding values");
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallVector<int64_t, 2> strideInts;
|
llvm::SmallVector<int64_t, 2> strideInts;
|
||||||
if (!matchPattern(stride, m_TorchConstantIntList(strideInts)))
|
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"only support constant int strides");
|
"only support constant int strides");
|
||||||
llvm::SmallVector<int64_t, 2> dilationInts;
|
llvm::SmallVector<int64_t, 2> dilationInts;
|
||||||
if (!matchPattern(dilation, m_TorchConstantIntList(dilationInts)))
|
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"only support constant int dilations");
|
"only support constant int dilations");
|
||||||
if (!op.bias().getType().isa<Torch::NoneType>())
|
if (!op.bias().getType().isa<Torch::NoneType>())
|
||||||
|
@ -905,30 +908,29 @@ public:
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
AtenMaxPool2dOp::Adaptor adaptor(operands);
|
AtenMaxPool2dOp::Adaptor adaptor(operands);
|
||||||
Value self = adaptor.self();
|
Value self = adaptor.self();
|
||||||
Value kernelSize = adaptor.kernel_size();
|
|
||||||
Value stride = adaptor.stride();
|
|
||||||
Value padding = adaptor.padding();
|
|
||||||
Value dilation = adaptor.dilation();
|
|
||||||
Value ceilMode = adaptor.ceil_mode();
|
Value ceilMode = adaptor.ceil_mode();
|
||||||
|
|
||||||
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
||||||
if (!elementType.isa<mlir::FloatType>())
|
if (!elementType.isa<mlir::FloatType>())
|
||||||
op.emitError("unimplemented: non-floating point type");
|
op.emitError("unimplemented: non-floating point type");
|
||||||
|
|
||||||
|
// Pattern match against the op's original operands, because otherwise we
|
||||||
|
// will get the lowered version of the operands which is harder to pattern
|
||||||
|
// match.
|
||||||
llvm::SmallVector<int64_t, 2> strideInts;
|
llvm::SmallVector<int64_t, 2> strideInts;
|
||||||
if (!matchPattern(stride, m_TorchConstantIntList(strideInts)))
|
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"only support constant int strides");
|
"only support constant int strides");
|
||||||
llvm::SmallVector<int64_t, 2> dilationInts;
|
llvm::SmallVector<int64_t, 2> dilationInts;
|
||||||
if (!matchPattern(dilation, m_TorchConstantIntList(dilationInts)))
|
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"only support constant int dilations");
|
"only support constant int dilations");
|
||||||
llvm::SmallVector<int64_t, 2> paddingInts;
|
llvm::SmallVector<int64_t, 2> paddingInts;
|
||||||
if (!matchPattern(padding, m_TorchConstantIntList(paddingInts)))
|
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"only support constant int paddings");
|
"only support constant int paddings");
|
||||||
llvm::SmallVector<int64_t, 2> kernelSizeInts;
|
llvm::SmallVector<int64_t, 2> kernelSizeInts;
|
||||||
if (!matchPattern(kernelSize, m_TorchConstantIntList(kernelSizeInts)))
|
if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts)))
|
||||||
return rewriter.notifyMatchFailure(op, "only support kernel size ints");
|
return rewriter.notifyMatchFailure(op, "only support kernel size ints");
|
||||||
|
|
||||||
Value falseValue = rewriter.create<ConstantOp>(
|
Value falseValue = rewriter.create<ConstantOp>(
|
||||||
|
@ -1113,6 +1115,7 @@ public:
|
||||||
registry.insert<math::MathDialect>();
|
registry.insert<math::MathDialect>();
|
||||||
registry.insert<StandardOpsDialect>();
|
registry.insert<StandardOpsDialect>();
|
||||||
registry.insert<tensor::TensorDialect>();
|
registry.insert<tensor::TensorDialect>();
|
||||||
|
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
|
@ -1123,7 +1126,7 @@ public:
|
||||||
|
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
typeConverter.addConversion([](Type type) { return type; });
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
setupBackendTypeConversion(target, typeConverter);
|
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||||
|
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
target.addIllegalOp<AtenMmOp>();
|
target.addIllegalOp<AtenMmOp>();
|
||||||
|
|
|
@ -16,4 +16,5 @@ add_npcomp_conversion_library(NPCOMPTorchToSCF
|
||||||
MLIRSCF
|
MLIRSCF
|
||||||
MLIRStandard
|
MLIRStandard
|
||||||
NPCOMPTorchDialect
|
NPCOMPTorchDialect
|
||||||
|
NPCOMPTorchConversionDialect
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,7 +13,8 @@
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h"
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||||
|
#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
|
@ -63,6 +64,7 @@ class ConvertTorchToSCF : public ConvertTorchToSCFBase<ConvertTorchToSCF> {
|
||||||
public:
|
public:
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<scf::SCFDialect>();
|
registry.insert<scf::SCFDialect>();
|
||||||
|
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
|
@ -72,7 +74,7 @@ public:
|
||||||
|
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
typeConverter.addConversion([](Type type) { return type; });
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
setupBackendTypeConversion(target, typeConverter);
|
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||||
|
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
target.addIllegalOp<PrimIfOp>();
|
target.addIllegalOp<PrimIfOp>();
|
||||||
|
|
|
@ -14,7 +14,8 @@
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h"
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||||
|
#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
|
@ -93,6 +94,7 @@ class ConvertTorchToStd : public ConvertTorchToStdBase<ConvertTorchToStd> {
|
||||||
public:
|
public:
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<StandardOpsDialect>();
|
registry.insert<StandardOpsDialect>();
|
||||||
|
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
|
@ -102,7 +104,7 @@ public:
|
||||||
|
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
typeConverter.addConversion([](Type type) { return type; });
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
setupBackendTypeConversion(target, typeConverter);
|
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||||
|
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
target.addIllegalOp<AtenDimOp>();
|
target.addIllegalOp<AtenDimOp>();
|
||||||
|
|
|
@ -3,3 +3,4 @@ add_subdirectory(Numpy)
|
||||||
add_subdirectory(Refback)
|
add_subdirectory(Refback)
|
||||||
add_subdirectory(Refbackrt)
|
add_subdirectory(Refbackrt)
|
||||||
add_subdirectory(Torch)
|
add_subdirectory(Torch)
|
||||||
|
add_subdirectory(TorchConversion)
|
||||||
|
|
|
@ -688,35 +688,6 @@ void CopyToValueTensorOp::getEffects(
|
||||||
effects.emplace_back(MemoryEffects::Read::get(), getOperand());
|
effects.emplace_back(MemoryEffects::Read::get(), getOperand());
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ToBuiltinTensorOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
LogicalResult ToBuiltinTensorOp::inferReturnTypes(
|
|
||||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
|
||||||
DictionaryAttr attributes, RegionRange regions,
|
|
||||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
|
||||||
auto resultType =
|
|
||||||
operands[0].getType().cast<ValueTensorType>().toBuiltinTensor();
|
|
||||||
if (!resultType)
|
|
||||||
return failure();
|
|
||||||
inferredReturnTypes.push_back(resultType);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// FromBuiltinTensorOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
LogicalResult FromBuiltinTensorOp::inferReturnTypes(
|
|
||||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
|
||||||
DictionaryAttr attributes, RegionRange regions,
|
|
||||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
|
||||||
inferredReturnTypes.push_back(
|
|
||||||
ValueTensorType::getFromShaped(operands[0].getType().cast<TensorType>()));
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ConstantNoneOp
|
// ConstantNoneOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
add_npcomp_conversion_library(NPCOMPTorchPasses
|
add_npcomp_conversion_library(NPCOMPTorchPasses
|
||||||
AdjustCallingConventions.cpp
|
AdjustCallingConventions.cpp
|
||||||
BackendTypeConversion.cpp
|
|
||||||
Passes.cpp
|
Passes.cpp
|
||||||
GlobalizeObjectGraph.cpp
|
GlobalizeObjectGraph.cpp
|
||||||
InlineGlobalSlots.cpp
|
InlineGlobalSlots.cpp
|
||||||
|
@ -9,7 +8,6 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
|
||||||
ReduceOpVariants.cpp
|
ReduceOpVariants.cpp
|
||||||
RefinePublicReturn.cpp
|
RefinePublicReturn.cpp
|
||||||
RefineTypes.cpp
|
RefineTypes.cpp
|
||||||
VerifyInvariantsBeforeBackendLowering.cpp
|
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms
|
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms
|
||||||
|
@ -24,7 +22,5 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRPass
|
MLIRPass
|
||||||
NPCOMPTorchDialect
|
NPCOMPTorchDialect
|
||||||
NPCOMPTorchToLinalg
|
|
||||||
NPCOMPInterfaces
|
NPCOMPInterfaces
|
||||||
MLIRMemRefTransforms
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -7,15 +7,8 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||||
#include "mlir/Dialect/Linalg/Passes.h"
|
|
||||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
|
||||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
#include "npcomp/Backend/Common/Passes.h"
|
|
||||||
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
|
|
||||||
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
|
|
||||||
#include "npcomp/Conversion/TorchToStd/TorchToStd.h"
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Pass registration
|
// Pass registration
|
||||||
|
@ -29,16 +22,16 @@ namespace {
|
||||||
void mlir::NPCOMP::registerTorchPasses() {
|
void mlir::NPCOMP::registerTorchPasses() {
|
||||||
::registerPasses();
|
::registerPasses();
|
||||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
"torchscript-to-npcomp-backend-pipeline",
|
"torchscript-to-torch-backend-pipeline",
|
||||||
"Pipeline lowering torch object graph to npcomp backend format.",
|
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
|
||||||
mlir::NPCOMP::Torch::createLowerObjectGraphPipeline);
|
mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline);
|
||||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
"torch-globalized-module-to-npcomp-backend-pipeline",
|
"torch-globalized-module-to-torch-backend-pipeline",
|
||||||
"Pipeline lowering to npcomp backend form.",
|
"Pipeline lowering a globalized Torch program to Torch backend form.",
|
||||||
mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline);
|
mlir::NPCOMP::Torch::createGlobalizedModuleToTorchBackendPipeline);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::NPCOMP::Torch::createLowerObjectGraphPipeline(
|
void mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline(
|
||||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||||
// When we import TorchScript IR, we import their entire "compilation unit",
|
// When we import TorchScript IR, we import their entire "compilation unit",
|
||||||
// which can contain numerous functions unrelated to the current program,
|
// which can contain numerous functions unrelated to the current program,
|
||||||
|
@ -66,10 +59,10 @@ void mlir::NPCOMP::Torch::createLowerObjectGraphPipeline(
|
||||||
// Incorporate user annotations and remove signature Python-isms.
|
// Incorporate user annotations and remove signature Python-isms.
|
||||||
pm.addPass(createAdjustCallingConventionsPass());
|
pm.addPass(createAdjustCallingConventionsPass());
|
||||||
|
|
||||||
createLowerToNpcompBackendPipeline(pm, options);
|
createGlobalizedModuleToTorchBackendPipeline(pm, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
|
void mlir::NPCOMP::Torch::createGlobalizedModuleToTorchBackendPipeline(
|
||||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||||
// General considerations: As a matter of bring-up, we are simultaneously
|
// General considerations: As a matter of bring-up, we are simultaneously
|
||||||
// building out the frontend pipeline and also co-developing the backend
|
// building out the frontend pipeline and also co-developing the backend
|
||||||
|
@ -140,39 +133,5 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
|
||||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||||
}
|
}
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
// TODO: VerifyTorchBackendContractPass.
|
||||||
// Lowering ops and the !torch.vtensor type to the npcomp backend contract.
|
|
||||||
//===--------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
// Check some invariants to catch errors in a clear way.
|
|
||||||
pm.addPass(Torch::createVerifyInvariantsBeforeBackendLoweringPass());
|
|
||||||
|
|
||||||
// Lower to linalg + guards which is the input to codegen backends.
|
|
||||||
// We do this first as it tends to involve pattern-matching against constants,
|
|
||||||
// (e.g. dimensions which must be constant in a ranked programming model)
|
|
||||||
// and those constants get somewhat obscured by TorchToStd.
|
|
||||||
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
|
|
||||||
|
|
||||||
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
|
|
||||||
pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass());
|
|
||||||
pm.addNestedPass<FuncOp>(createStdExpandOpsPass());
|
|
||||||
|
|
||||||
if (options.optimize) {
|
|
||||||
// Clean up any non-canonical code introduced in our linalg lowering.
|
|
||||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
|
||||||
// Resolve `dim` ops on tensors (which currently live in the `memref`
|
|
||||||
// dialect for some reason -- we don't have memrefs at this level).
|
|
||||||
pm.addNestedPass<FuncOp>(memref::createResolveShapedTypeResultDimsPass());
|
|
||||||
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
|
||||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finish the type conversion from !torch.vtensor to the builtin tensor type.
|
|
||||||
pm.addPass(createFuncBackendTypeConversionPass());
|
|
||||||
pm.addNestedPass<FuncOp>(createFinalizingBackendTypeConversionPass());
|
|
||||||
|
|
||||||
// Verify that we have lowered to the form that backends expect.
|
|
||||||
// This fails compilation (signalPassFailure) if the IR is not in the
|
|
||||||
// correct form.
|
|
||||||
pm.addPass(CommonBackend::createVerifyBackendContractPass());
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
add_subdirectory(IR)
|
||||||
|
add_subdirectory(Transforms)
|
|
@ -0,0 +1,18 @@
|
||||||
|
add_npcomp_dialect_library(NPCOMPTorchConversionDialect
|
||||||
|
TorchConversionDialect.cpp
|
||||||
|
TorchConversionOps.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/TorchConversion
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
MLIRTorchConversionOpsIncGen
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRSupport
|
||||||
|
MLIRSideEffectInterfaces
|
||||||
|
)
|
|
@ -0,0 +1,51 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
|
#include "mlir/Transforms/InliningUtils.h"
|
||||||
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
using namespace mlir::NPCOMP::TorchConversion;
|
||||||
|
|
||||||
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.cpp.inc"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Dialect Interfaces
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct TorchConversionInlinerInterface : public DialectInlinerInterface {
|
||||||
|
using DialectInlinerInterface::DialectInlinerInterface;
|
||||||
|
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
|
||||||
|
BlockAndValueMapping &valueMapping) const final {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
|
||||||
|
BlockAndValueMapping &) const final {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Dialect initialize method.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void TorchConversionDialect::initialize() {
|
||||||
|
addOperations<
|
||||||
|
#define GET_OP_LIST
|
||||||
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc"
|
||||||
|
>();
|
||||||
|
addInterfaces<TorchConversionInlinerInterface>();
|
||||||
|
}
|
|
@ -0,0 +1,52 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||||
|
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||||
|
#include "llvm/ADT/StringMap.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
using namespace mlir::NPCOMP::TorchConversion;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ToBuiltinTensorOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult ToBuiltinTensorOp::inferReturnTypes(
|
||||||
|
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||||
|
DictionaryAttr attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||||
|
auto resultType =
|
||||||
|
operands[0].getType().cast<Torch::ValueTensorType>().toBuiltinTensor();
|
||||||
|
if (!resultType)
|
||||||
|
return failure();
|
||||||
|
inferredReturnTypes.push_back(resultType);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// FromBuiltinTensorOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult FromBuiltinTensorOp::inferReturnTypes(
|
||||||
|
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||||
|
DictionaryAttr attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||||
|
inferredReturnTypes.push_back(Torch::ValueTensorType::getFromShaped(
|
||||||
|
operands[0].getType().cast<TensorType>()));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc"
|
|
@ -8,24 +8,36 @@
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
|
|
||||||
|
#include "iree-dialects/Dialect/IREE/IREEDialect.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h"
|
#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::NPCOMP::TorchConversion;
|
||||||
|
|
||||||
|
void mlir::NPCOMP::TorchConversion::getBackendTypeConversionDependentDialects(
|
||||||
|
DialectRegistry ®istry) {
|
||||||
|
registry.insert<TorchConversionDialect>();
|
||||||
|
registry.insert<iree::IREEDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Type conversion setup.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static void
|
static void
|
||||||
setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
|
setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
|
||||||
TypeConverter &typeConverter) {
|
TypeConverter &typeConverter) {
|
||||||
target.addLegalOp<Torch::ToBuiltinTensorOp, Torch::FromBuiltinTensorOp>();
|
target.addLegalOp<TorchConversion::ToBuiltinTensorOp,
|
||||||
|
TorchConversion::FromBuiltinTensorOp>();
|
||||||
typeConverter.addConversion(
|
typeConverter.addConversion(
|
||||||
[](Torch::ValueTensorType type) -> Optional<Type> {
|
[](Torch::ValueTensorType type) -> Optional<Type> {
|
||||||
return type.toBuiltinTensor();
|
return type.toBuiltinTensor();
|
||||||
|
@ -34,10 +46,11 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
|
||||||
ValueRange inputs,
|
ValueRange inputs,
|
||||||
Location loc) -> Value {
|
Location loc) -> Value {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(inputs[0].getType().isa<BaseTensorType>());
|
assert(inputs[0].getType().isa<Torch::BaseTensorType>());
|
||||||
return builder.create<ToBuiltinTensorOp>(loc, inputs[0]);
|
return builder.create<ToBuiltinTensorOp>(loc, inputs[0]);
|
||||||
});
|
});
|
||||||
auto sourceMaterialization = [](OpBuilder &builder, ValueTensorType type,
|
auto sourceMaterialization = [](OpBuilder &builder,
|
||||||
|
Torch::ValueTensorType type,
|
||||||
ValueRange inputs, Location loc) -> Value {
|
ValueRange inputs, Location loc) -> Value {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(inputs[0].getType().isa<TensorType>());
|
assert(inputs[0].getType().isa<TensorType>());
|
||||||
|
@ -49,7 +62,7 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
|
||||||
|
|
||||||
static void setupTorchBoolToI1Conversion(ConversionTarget &target,
|
static void setupTorchBoolToI1Conversion(ConversionTarget &target,
|
||||||
TypeConverter &typeConverter) {
|
TypeConverter &typeConverter) {
|
||||||
target.addLegalOp<Torch::ToI1Op, Torch::FromI1Op>();
|
target.addLegalOp<TorchConversion::ToI1Op, TorchConversion::FromI1Op>();
|
||||||
typeConverter.addConversion([](Torch::BoolType type) -> Optional<Type> {
|
typeConverter.addConversion([](Torch::BoolType type) -> Optional<Type> {
|
||||||
return IntegerType::get(type.getContext(), 1);
|
return IntegerType::get(type.getContext(), 1);
|
||||||
});
|
});
|
||||||
|
@ -75,7 +88,7 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target,
|
||||||
|
|
||||||
static void setupTorchIntToI64Conversion(ConversionTarget &target,
|
static void setupTorchIntToI64Conversion(ConversionTarget &target,
|
||||||
TypeConverter &typeConverter) {
|
TypeConverter &typeConverter) {
|
||||||
target.addLegalOp<Torch::ToI64Op, Torch::FromI64Op>();
|
target.addLegalOp<TorchConversion::ToI64Op, TorchConversion::FromI64Op>();
|
||||||
typeConverter.addConversion([](Torch::IntType type) -> Optional<Type> {
|
typeConverter.addConversion([](Torch::IntType type) -> Optional<Type> {
|
||||||
return IntegerType::get(type.getContext(), 64);
|
return IntegerType::get(type.getContext(), 64);
|
||||||
});
|
});
|
||||||
|
@ -101,7 +114,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
|
||||||
|
|
||||||
static void setupTorchFloatToF64Conversion(ConversionTarget &target,
|
static void setupTorchFloatToF64Conversion(ConversionTarget &target,
|
||||||
TypeConverter &typeConverter) {
|
TypeConverter &typeConverter) {
|
||||||
target.addLegalOp<Torch::ToF64Op, Torch::FromF64Op>();
|
target.addLegalOp<TorchConversion::ToF64Op, TorchConversion::FromF64Op>();
|
||||||
typeConverter.addConversion([](Torch::FloatType type) -> Optional<Type> {
|
typeConverter.addConversion([](Torch::FloatType type) -> Optional<Type> {
|
||||||
return Float64Type::get(type.getContext());
|
return Float64Type::get(type.getContext());
|
||||||
});
|
});
|
||||||
|
@ -122,12 +135,38 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target,
|
||||||
typeConverter.addArgumentMaterialization(sourceMaterialization);
|
typeConverter.addArgumentMaterialization(sourceMaterialization);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::NPCOMP::Torch::setupBackendTypeConversion(
|
static void setupTorchListToIREEListConversion(ConversionTarget &target,
|
||||||
|
TypeConverter &typeConverter) {
|
||||||
|
target.addLegalOp<TorchConversion::ToIREEListOp,
|
||||||
|
TorchConversion::FromIREEListOp>();
|
||||||
|
typeConverter.addConversion([&](Torch::ListType type) -> Optional<Type> {
|
||||||
|
return iree::ListType::get(
|
||||||
|
type.getContext(), typeConverter.convertType(type.getContainedType()));
|
||||||
|
});
|
||||||
|
typeConverter.addTargetMaterialization(
|
||||||
|
[](OpBuilder &builder, iree::ListType type, ValueRange inputs,
|
||||||
|
Location loc) -> Optional<Value> {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(inputs[0].getType().isa<Torch::ListType>());
|
||||||
|
return builder.create<ToIREEListOp>(loc, type, inputs[0]).getResult();
|
||||||
|
});
|
||||||
|
auto sourceMaterialization = [](OpBuilder &builder, Torch::ListType type,
|
||||||
|
ValueRange inputs, Location loc) -> Value {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(inputs[0].getType().isa<iree::ListType>());
|
||||||
|
return builder.create<FromIREEListOp>(loc, type, inputs[0]);
|
||||||
|
};
|
||||||
|
typeConverter.addSourceMaterialization(sourceMaterialization);
|
||||||
|
typeConverter.addArgumentMaterialization(sourceMaterialization);
|
||||||
|
}
|
||||||
|
|
||||||
|
void mlir::NPCOMP::TorchConversion::setupBackendTypeConversion(
|
||||||
ConversionTarget &target, TypeConverter &typeConverter) {
|
ConversionTarget &target, TypeConverter &typeConverter) {
|
||||||
setupValueTensorToBuiltinTensorConversion(target, typeConverter);
|
setupValueTensorToBuiltinTensorConversion(target, typeConverter);
|
||||||
setupTorchBoolToI1Conversion(target, typeConverter);
|
setupTorchBoolToI1Conversion(target, typeConverter);
|
||||||
setupTorchIntToI64Conversion(target, typeConverter);
|
setupTorchIntToI64Conversion(target, typeConverter);
|
||||||
setupTorchFloatToF64Conversion(target, typeConverter);
|
setupTorchFloatToF64Conversion(target, typeConverter);
|
||||||
|
setupTorchListToIREEListConversion(target, typeConverter);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -139,6 +178,9 @@ struct FuncBackendTypeConversionPass
|
||||||
: public FuncBackendTypeConversionBase<FuncBackendTypeConversionPass> {
|
: public FuncBackendTypeConversionBase<FuncBackendTypeConversionPass> {
|
||||||
using FuncBackendTypeConversionBase<
|
using FuncBackendTypeConversionBase<
|
||||||
FuncBackendTypeConversionPass>::FuncBackendTypeConversionBase;
|
FuncBackendTypeConversionPass>::FuncBackendTypeConversionBase;
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<TorchConversion::TorchConversionDialect>();
|
||||||
|
}
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
auto module = getOperation();
|
auto module = getOperation();
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
|
@ -147,7 +189,7 @@ struct FuncBackendTypeConversionPass
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
typeConverter.addConversion([](Type type) { return type; });
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
setupBackendTypeConversion(target, typeConverter);
|
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||||
|
|
||||||
populateFuncOpTypeConversionPattern(patterns, typeConverter);
|
populateFuncOpTypeConversionPattern(patterns, typeConverter);
|
||||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||||
|
@ -176,7 +218,7 @@ struct FuncBackendTypeConversionPass
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
mlir::NPCOMP::Torch::createFuncBackendTypeConversionPass() {
|
mlir::NPCOMP::TorchConversion::createFuncBackendTypeConversionPass() {
|
||||||
return std::make_unique<FuncBackendTypeConversionPass>();
|
return std::make_unique<FuncBackendTypeConversionPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -234,13 +276,13 @@ struct FinalizingBackendTypeConversionPass
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
|
|
||||||
typeConverter.addConversion([](Type type) { return type; });
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
setupBackendTypeConversion(target, typeConverter);
|
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||||
|
|
||||||
// Mark materializations as illegal in this pass (since we are finalizing)
|
// Mark materializations as illegal in this pass (since we are finalizing)
|
||||||
// and add patterns that eliminate them.
|
// and add patterns that eliminate them.
|
||||||
setupFinalization<ToBuiltinTensorOp, FromBuiltinTensorOp, FromI1Op, ToI1Op,
|
setupFinalization<ToBuiltinTensorOp, FromBuiltinTensorOp, FromI1Op, ToI1Op,
|
||||||
FromI64Op, ToI64Op, FromF64Op, ToF64Op>(target, patterns,
|
FromI64Op, ToI64Op, FromF64Op, ToF64Op, FromIREEListOp,
|
||||||
typeConverter);
|
ToIREEListOp>(target, patterns, typeConverter);
|
||||||
|
|
||||||
// If all result types are legal, and all block arguments are legal, then
|
// If all result types are legal, and all block arguments are legal, then
|
||||||
// all types in the program are legal.
|
// all types in the program are legal.
|
||||||
|
@ -259,6 +301,6 @@ struct FinalizingBackendTypeConversionPass
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
mlir::NPCOMP::Torch::createFinalizingBackendTypeConversionPass() {
|
mlir::NPCOMP::TorchConversion::createFinalizingBackendTypeConversionPass() {
|
||||||
return std::make_unique<FinalizingBackendTypeConversionPass>();
|
return std::make_unique<FinalizingBackendTypeConversionPass>();
|
||||||
}
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
add_npcomp_conversion_library(NPCOMPTorchConversionPasses
|
||||||
|
BackendTypeConversion.cpp
|
||||||
|
Passes.cpp
|
||||||
|
TmpDeleteDeadIREELists.cpp
|
||||||
|
VerifyInvariantsBeforeBackendLowering.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/TorchConversion/Transforms
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
NPCOMPTorchConversionPassIncGen
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRPass
|
||||||
|
NPCOMPTorchConversionDialect
|
||||||
|
NPCOMPTorchDialect
|
||||||
|
NPCOMPTorchPasses
|
||||||
|
NPCOMPTorchToIREE
|
||||||
|
NPCOMPTorchToLinalg
|
||||||
|
NPCOMPTorchToStd
|
||||||
|
NPCOMPTorchToSCF
|
||||||
|
NPCOMPInterfaces
|
||||||
|
MLIRMemRefTransforms
|
||||||
|
)
|
|
@ -0,0 +1,25 @@
|
||||||
|
//===- PassDetail.h - Pass details ------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H
|
||||||
|
#define NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace NPCOMP {
|
||||||
|
namespace TorchConversion {
|
||||||
|
|
||||||
|
#define GEN_PASS_CLASSES
|
||||||
|
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h.inc"
|
||||||
|
|
||||||
|
} // namespace TorchConversion
|
||||||
|
} // namespace NPCOMP
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H
|
|
@ -0,0 +1,97 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h"
|
||||||
|
#include "mlir/Dialect/Linalg/Passes.h"
|
||||||
|
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||||
|
#include "mlir/Pass/PassManager.h"
|
||||||
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
#include "npcomp/Backend/Common/Passes.h"
|
||||||
|
#include "npcomp/Conversion/TorchToIREE/TorchToIREE.h"
|
||||||
|
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||||
|
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
|
||||||
|
#include "npcomp/Conversion/TorchToStd/TorchToStd.h"
|
||||||
|
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pass registration
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
#define GEN_PASS_REGISTRATION
|
||||||
|
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h.inc"
|
||||||
|
} // end namespace
|
||||||
|
|
||||||
|
void mlir::NPCOMP::registerTorchConversionPasses() {
|
||||||
|
::registerPasses();
|
||||||
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
|
"torchscript-to-npcomp-backend-pipeline",
|
||||||
|
"Pipeline lowering torch object graph to npcomp backend format.",
|
||||||
|
mlir::NPCOMP::TorchConversion::createTorchScriptToNpcompBackendPipeline);
|
||||||
|
}
|
||||||
|
|
||||||
|
void mlir::NPCOMP::TorchConversion::createTorchScriptToNpcompBackendPipeline(
|
||||||
|
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
||||||
|
|
||||||
|
// Conversion to the npcomp backend contract starts from the Torch backend
|
||||||
|
// contract.
|
||||||
|
Torch::createTorchScriptToTorchBackendPipeline(pm, options);
|
||||||
|
|
||||||
|
// Check some invariants to catch errors in a clear way.
|
||||||
|
pm.addPass(
|
||||||
|
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
|
||||||
|
|
||||||
|
// Lower to linalg + guards which is the input to codegen backends.
|
||||||
|
// We do this first as it tends to involve pattern-matching against constants,
|
||||||
|
// (e.g. dimensions which must be constant in a ranked programming model)
|
||||||
|
// and those constants get somewhat obscured by TorchToStd.
|
||||||
|
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
|
||||||
|
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
|
||||||
|
pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass());
|
||||||
|
// Lists and other concepts that don't exist in upstream go through the IREE
|
||||||
|
// dialect, which we treat as an reasonably well designed interim placeholder
|
||||||
|
// for the set of ops that we think makes sense in the npcomp backend
|
||||||
|
// contract. We expect to co-evolve this dialect with npcomp needs, as a lot
|
||||||
|
// of what we are doing here in npcomp is breaking new ground w.r.t.
|
||||||
|
// expressiveness and program generality for tensor compilers.
|
||||||
|
//
|
||||||
|
// We lower lists last because the lowered form is much harder to reason about
|
||||||
|
// than the original form.
|
||||||
|
pm.addNestedPass<FuncOp>(createConvertTorchToIREEPass());
|
||||||
|
pm.addNestedPass<FuncOp>(createStdExpandOpsPass());
|
||||||
|
|
||||||
|
if (options.optimize) {
|
||||||
|
// Clean up any non-canonical code introduced above..
|
||||||
|
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||||
|
// Resolve `dim` ops on tensors (which currently live in the `memref`
|
||||||
|
// dialect for some reason -- we don't have memrefs at this level).
|
||||||
|
pm.addNestedPass<FuncOp>(memref::createResolveShapedTypeResultDimsPass());
|
||||||
|
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
||||||
|
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finish the type conversion from `torch` types to the types of the npcomp
|
||||||
|
// backend contract.
|
||||||
|
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||||
|
pm.addNestedPass<FuncOp>(
|
||||||
|
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||||
|
|
||||||
|
// Temporarily delete dead list ops until IREE can run them e2e.
|
||||||
|
// TODO: Remove this pass once IREE can run them e2e.
|
||||||
|
// TODO: Add support to IREE to run these ops E2E.
|
||||||
|
pm.addNestedPass<FuncOp>(TorchConversion::createTmpDeleteDeadIREEListsPass());
|
||||||
|
|
||||||
|
// Verify that we have lowered to the form that npcomp backends expect.
|
||||||
|
// This fails compilation (signalPassFailure) if the IR is not in the
|
||||||
|
// correct form.
|
||||||
|
pm.addPass(CommonBackend::createVerifyBackendContractPass());
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
//===- TmpDeleteDeadIREELists.cpp --------------------------------*- C++-*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "PassDetail.h"
|
||||||
|
|
||||||
|
#include "iree-dialects/Dialect/IREE/IREEOps.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||||
|
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
using namespace mlir::NPCOMP::TorchConversion;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class TmpDeleteDeadIREEListsPass
|
||||||
|
: public TmpDeleteDeadIREEListsBase<TmpDeleteDeadIREEListsPass> {
|
||||||
|
void runOnOperation() override {
|
||||||
|
SmallVector<Operation *> toErase;
|
||||||
|
// Delete lists that are only set (but not read from).
|
||||||
|
// This is created by our lowering for torch.prim.ListConstruct.
|
||||||
|
// Until IREE can run such ops e2e (or delete them itself), we need to
|
||||||
|
// do this cleanup.
|
||||||
|
// TODO: Add support to IREE to run these ops E2E.
|
||||||
|
getOperation().walk([&](iree::ListCreateOp op) {
|
||||||
|
SmallVector<Operation *> deadOps;
|
||||||
|
deadOps.push_back(op);
|
||||||
|
for (auto &use : op.getResult().getUses()) {
|
||||||
|
if (isa<iree::ListSetOp>(use.getOwner())) {
|
||||||
|
deadOps.push_back(use.getOwner());
|
||||||
|
} else {
|
||||||
|
// We can't analyze the list op if it is used by something else.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
llvm::append_range(toErase, deadOps);
|
||||||
|
});
|
||||||
|
for (auto *op : toErase) {
|
||||||
|
op->dropAllDefinedValueUses();
|
||||||
|
op->erase();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
|
mlir::NPCOMP::TorchConversion::createTmpDeleteDeadIREEListsPass() {
|
||||||
|
return std::make_unique<TmpDeleteDeadIREEListsPass>();
|
||||||
|
}
|
|
@ -11,17 +11,18 @@
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||||
|
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::NPCOMP::TorchConversion;
|
||||||
|
|
||||||
static LogicalResult checkValueInvariants(Operation *errorReportOp, Value v) {
|
static LogicalResult checkValueInvariants(Operation *errorReportOp, Value v) {
|
||||||
// TODO: Make this an allowlist instead of a denylist.
|
// TODO: Make this an allowlist instead of a denylist.
|
||||||
// TODO: Make this stricter.
|
// TODO: Make this stricter.
|
||||||
auto type = v.getType();
|
auto type = v.getType();
|
||||||
if (auto valueTensorType = type.dyn_cast<ValueTensorType>()) {
|
if (auto valueTensorType = type.dyn_cast<Torch::ValueTensorType>()) {
|
||||||
if (!valueTensorType.hasDtype() || !valueTensorType.hasSizes())
|
if (!valueTensorType.hasDtype() || !valueTensorType.hasSizes())
|
||||||
return errorReportOp->emitError()
|
return errorReportOp->emitError()
|
||||||
.append("unsupported by backend lowering: tensor with unknown rank "
|
.append("unsupported by backend lowering: tensor with unknown rank "
|
||||||
|
@ -77,7 +78,7 @@ class VerifyInvariantsBeforeBackendLoweringPass
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>> mlir::NPCOMP::TorchConversion::
|
||||||
mlir::NPCOMP::Torch::createVerifyInvariantsBeforeBackendLoweringPass() {
|
createVerifyInvariantsBeforeBackendLoweringPass() {
|
||||||
return std::make_unique<VerifyInvariantsBeforeBackendLoweringPass>();
|
return std::make_unique<VerifyInvariantsBeforeBackendLoweringPass>();
|
||||||
}
|
}
|
|
@ -8,6 +8,7 @@
|
||||||
|
|
||||||
#include "npcomp/InitAll.h"
|
#include "npcomp/InitAll.h"
|
||||||
|
|
||||||
|
#include "iree-dialects/Dialect/IREE/IREEDialect.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "npcomp/Backend/Common/Passes.h"
|
#include "npcomp/Backend/Common/Passes.h"
|
||||||
#include "npcomp/Backend/IREE/Passes.h"
|
#include "npcomp/Backend/IREE/Passes.h"
|
||||||
|
@ -20,6 +21,8 @@
|
||||||
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
|
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||||
|
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h"
|
||||||
#include "npcomp/RefBackend/RefBackend.h"
|
#include "npcomp/RefBackend/RefBackend.h"
|
||||||
#include "npcomp/Typing/Transforms/Passes.h"
|
#include "npcomp/Typing/Transforms/Passes.h"
|
||||||
|
|
||||||
|
@ -29,7 +32,9 @@ void mlir::NPCOMP::registerAllDialects(mlir::DialectRegistry ®istry) {
|
||||||
Numpy::NumpyDialect,
|
Numpy::NumpyDialect,
|
||||||
refbackrt::RefbackrtDialect,
|
refbackrt::RefbackrtDialect,
|
||||||
refback::RefbackDialect,
|
refback::RefbackDialect,
|
||||||
mlir::NPCOMP::Torch::TorchDialect>();
|
mlir::NPCOMP::Torch::TorchDialect,
|
||||||
|
mlir::NPCOMP::TorchConversion::TorchConversionDialect,
|
||||||
|
iree::IREEDialect>();
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,6 +44,7 @@ void mlir::NPCOMP::registerAllPasses() {
|
||||||
mlir::NPCOMP::registerBasicpyPasses();
|
mlir::NPCOMP::registerBasicpyPasses();
|
||||||
mlir::NPCOMP::registerNumpyPasses();
|
mlir::NPCOMP::registerNumpyPasses();
|
||||||
mlir::NPCOMP::registerTorchPasses();
|
mlir::NPCOMP::registerTorchPasses();
|
||||||
|
mlir::NPCOMP::registerTorchConversionPasses();
|
||||||
mlir::NPCOMP::registerTypingPasses();
|
mlir::NPCOMP::registerTypingPasses();
|
||||||
mlir::NPCOMP::IREEBackend::registerIREEBackendPasses();
|
mlir::NPCOMP::IREEBackend::registerIREEBackendPasses();
|
||||||
mlir::NPCOMP::CommonBackend::registerCommonBackendPasses();
|
mlir::NPCOMP::CommonBackend::registerCommonBackendPasses();
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
|
||||||
|
// RUN: npcomp-opt <%s -convert-torch-to-iree -split-input-file -verify-diagnostics | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: builtin.func @forward(
|
||||||
|
// CHECK-SAME: %[[ARG_TORCH:.*]]: !torch.float) -> !torch.list<!torch.float> {
|
||||||
|
// CHECK: %[[ARG:.*]] = torch_c.to_f64 %[[ARG_TORCH]]
|
||||||
|
// CHECK: %[[ALSO_ARG:.*]] = torch_c.to_f64 %[[ARG_TORCH]]
|
||||||
|
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||||
|
// CHECK: %[[LIST:.*]] = iree.list.create %[[C2]] : !iree.list<f64>
|
||||||
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
|
// CHECK: iree.list.set %[[LIST]][%[[C0]]], %[[ARG]] : !iree.list<f64>, f64
|
||||||
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
|
// CHECK: iree.list.set %[[LIST]][%[[C1]]], %[[ALSO_ARG]] : !iree.list<f64>, f64
|
||||||
|
// CHECK: %[[LIST_TORCH:.*]] = torch_c.from_iree_list %[[LIST]] : !iree.list<f64> -> !torch.list<!torch.float>
|
||||||
|
// CHECK: return %[[LIST_TORCH]] : !torch.list<!torch.float>
|
||||||
|
builtin.func @forward(%arg0: !torch.float) -> !torch.list<!torch.float> {
|
||||||
|
%0 = torch.prim.ListConstruct %arg0, %arg0 : (!torch.float, !torch.float) -> !torch.list<!torch.float>
|
||||||
|
return %0 : !torch.list<!torch.float>
|
||||||
|
}
|
|
@ -3,8 +3,8 @@
|
||||||
// CHECK-LABEL: func @torch.aten.mm$basic(
|
// CHECK-LABEL: func @torch.aten.mm$basic(
|
||||||
// CHECK-SAME: %[[LHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>,
|
// CHECK-SAME: %[[LHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>,
|
||||||
// CHECK-SAME: %[[RHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> {
|
// CHECK-SAME: %[[RHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> {
|
||||||
// CHECK: %[[LHS:.*]] = torch.to_builtin_tensor %[[LHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
// CHECK: %[[LHS:.*]] = torch_c.to_builtin_tensor %[[LHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
// CHECK: %[[RHS:.*]] = torch.to_builtin_tensor %[[RHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
// CHECK: %[[RHS:.*]] = torch_c.to_builtin_tensor %[[RHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK: %[[LHS_DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
|
// CHECK: %[[LHS_DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
|
@ -20,7 +20,7 @@
|
||||||
// CHECK: %[[ZEROFILL:.*]] = linalg.fill(%[[CF0]], %[[INIT_TENSOR]]) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
|
// CHECK: %[[ZEROFILL:.*]] = linalg.fill(%[[CF0]], %[[INIT_TENSOR]]) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
|
||||||
// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ZEROFILL]] : tensor<?x?xf32>) -> tensor<?x?xf32>
|
// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ZEROFILL]] : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
// CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<?x?xf32> to tensor<?x2xf32>
|
// CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<?x?xf32> to tensor<?x2xf32>
|
||||||
// CHECK: %[[RESULT_VTENSOR:.*]] = torch.from_builtin_tensor %[[CASTED]] : tensor<?x2xf32> -> !torch.vtensor<[?,2],f32>
|
// CHECK: %[[RESULT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x2xf32> -> !torch.vtensor<[?,2],f32>
|
||||||
// CHECK: return %[[RESULT_VTENSOR]] : !torch.vtensor<[?,2],f32>
|
// CHECK: return %[[RESULT_VTENSOR]] : !torch.vtensor<[?,2],f32>
|
||||||
func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> {
|
func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> {
|
||||||
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32>
|
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32>
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
// CHECK-LABEL: func @elementwise$unary(
|
// CHECK-LABEL: func @elementwise$unary(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
|
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||||
// CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor [] : tensor<f32>
|
// CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor [] : tensor<f32>
|
||||||
// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor<f32>) outs(%[[INIT_TENSOR]] : tensor<f32>) {
|
// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor<f32>) outs(%[[INIT_TENSOR]] : tensor<f32>) {
|
||||||
// CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32):
|
// CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32):
|
||||||
|
@ -11,7 +11,7 @@
|
||||||
// CHECK: linalg.yield %[[TANH]] : f32
|
// CHECK: linalg.yield %[[TANH]] : f32
|
||||||
// CHECK: } -> tensor<f32>
|
// CHECK: } -> tensor<f32>
|
||||||
// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<f32> to tensor<f32>
|
// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<f32> to tensor<f32>
|
||||||
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[],f32>
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[],f32>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||||
|
@ -22,8 +22,8 @@ func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32>
|
||||||
// CHECK-LABEL: func @elementwise$binary(
|
// CHECK-LABEL: func @elementwise$binary(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
|
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
// CHECK: %[[BUILTIN_ARG0:.*]] = torch.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
// CHECK: %[[BUILTIN_ARG0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
// CHECK: %[[BUILTIN_ARG1:.*]] = torch.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?],f32> -> tensor<?xf32>
|
// CHECK: %[[BUILTIN_ARG1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?],f32> -> tensor<?xf32>
|
||||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK: %[[ARG0_DIM0:.*]] = tensor.dim %[[BUILTIN_ARG0]], %[[C0]] : tensor<?x?xf32>
|
// CHECK: %[[ARG0_DIM0:.*]] = tensor.dim %[[BUILTIN_ARG0]], %[[C0]] : tensor<?x?xf32>
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
|
@ -39,7 +39,7 @@ func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32>
|
||||||
// CHECK: linalg.yield %[[MUL]] : f32
|
// CHECK: linalg.yield %[[MUL]] : f32
|
||||||
// CHECK: } -> tensor<?x?xf32>
|
// CHECK: } -> tensor<?x?xf32>
|
||||||
// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<?x?xf32> to tensor<?x?xf32>
|
// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<?x?xf32> to tensor<?x?xf32>
|
||||||
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[CASTED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
||||||
func @elementwise$binary(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
|
func @elementwise$binary(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
%0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32>
|
%0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
@ -61,7 +61,7 @@ func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vten
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>,
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>,
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> {
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> {
|
||||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||||
// CHECK: %[[BUILTIN_C1:.*]] = torch.to_i64 %[[C1]]
|
// CHECK: %[[BUILTIN_C1:.*]] = torch_c.to_i64 %[[C1]]
|
||||||
// CHECK: linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>]
|
// CHECK: linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>]
|
||||||
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32):
|
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32):
|
||||||
// CHECK: %[[ALPHA:.*]] = sitofp %[[BUILTIN_C1]] : i64 to f32
|
// CHECK: %[[ALPHA:.*]] = sitofp %[[BUILTIN_C1]] : i64 to f32
|
||||||
|
|
|
@ -4,10 +4,10 @@
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic(
|
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic(
|
||||||
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
|
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
|
||||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
|
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
|
||||||
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
|
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
|
||||||
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32>
|
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32>
|
||||||
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
|
||||||
|
|
||||||
func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
|
func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
|
||||||
|
@ -21,10 +21,10 @@ func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic_negative(
|
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic_negative(
|
||||||
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
|
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
|
||||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
|
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
|
||||||
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
|
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
|
||||||
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32>
|
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32>
|
||||||
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
|
||||||
|
|
||||||
func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
|
func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
|
||||||
|
@ -38,10 +38,10 @@ func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,
|
||||||
|
|
||||||
// CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_front(
|
// CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_front(
|
||||||
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
|
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
|
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
|
||||||
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32>
|
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32>
|
||||||
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<18x2xf32> to tensor<?x?xf32>
|
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<18x2xf32> to tensor<?x?xf32>
|
||||||
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
||||||
|
|
||||||
func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
|
func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
@ -55,10 +55,10 @@ func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2
|
||||||
|
|
||||||
// CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_back(
|
// CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_back(
|
||||||
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
|
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
|
||||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
|
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
|
||||||
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32>
|
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32>
|
||||||
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x12xf32> to tensor<?x12xf32>
|
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x12xf32> to tensor<?x12xf32>
|
||||||
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<?x12xf32> -> !torch.vtensor<[?,12],f32>
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<?x12xf32> -> !torch.vtensor<[?,12],f32>
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,12],f32>
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,12],f32>
|
||||||
|
|
||||||
func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
|
func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
|
||||||
|
@ -72,9 +72,9 @@ func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2]
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.flatten.using_ints$rank0(
|
// CHECK-LABEL: func @torch.aten.flatten.using_ints$rank0(
|
||||||
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
||||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
|
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||||
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
|
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
|
||||||
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[COLLAPSED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
|
||||||
|
|
||||||
func @torch.aten.flatten.using_ints$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
func @torch.aten.flatten.using_ints$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
// RUN: npcomp-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: builtin.func @forward
|
||||||
|
builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%int4 = torch.constant.int 4
|
||||||
|
%int5 = torch.constant.int 5
|
||||||
|
%int6 = torch.constant.int 6
|
||||||
|
%int7 = torch.constant.int 7
|
||||||
|
%int8 = torch.constant.int 8
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
// CHECK: %[[PADDED:.*]] = linalg.pad_tensor %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
|
||||||
|
// CHECK: %[[NEUTRAL:.*]] = constant -1.401300e-45 : f32
|
||||||
|
// CHECK: %[[OUT:.*]] = linalg.fill(%[[NEUTRAL]], %{{.*}}) : f32, tensor<?x?x?x?xf32> -> tensor<?x?x?x?xf32>
|
||||||
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
|
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||||
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[C1]], %[[C2]]] : tensor<?x?xf32>
|
||||||
|
// CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?x?xf32>, tensor<?x?xf32>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||||
|
%kernel_size = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||||
|
%stride = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||||
|
%padding = torch.prim.ListConstruct %int5, %int6 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||||
|
%dilation = torch.prim.ListConstruct %int7, %int8 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||||
|
%4 = torch.aten.max_pool2d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
|
||||||
|
return %4 : !torch.vtensor<[?,?,?,?],f32>
|
||||||
|
}
|
|
@ -5,9 +5,9 @@
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.unsqueeze$basic(
|
// CHECK-LABEL: func @torch.aten.unsqueeze$basic(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
||||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
|
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||||
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
|
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
|
||||||
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
|
||||||
func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
||||||
%int0 = torch.constant.int 0
|
%int0 = torch.constant.int 0
|
||||||
|
@ -17,9 +17,9 @@ func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtenso
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.unsqueeze$basic_negative(
|
// CHECK-LABEL: func @torch.aten.unsqueeze$basic_negative(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
||||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
|
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||||
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
|
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
|
||||||
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
|
||||||
func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
||||||
%int-1 = torch.constant.int -1
|
%int-1 = torch.constant.int -1
|
||||||
|
@ -29,9 +29,9 @@ func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !tor
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_front(
|
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_front(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
|
||||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
|
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
|
||||||
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32>
|
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32>
|
||||||
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[EXPANDED]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32>
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32>
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,3,4],f32>
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,3,4],f32>
|
||||||
func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
|
func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
|
||||||
%int0 = torch.constant.int 0
|
%int0 = torch.constant.int 0
|
||||||
|
@ -41,9 +41,9 @@ func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>)
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_back(
|
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_back(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
|
||||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
|
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
|
||||||
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32>
|
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32>
|
||||||
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x4x1xf32> -> !torch.vtensor<[2,3,4,1],f32>
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x4x1xf32> -> !torch.vtensor<[2,3,4,1],f32>
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4,1],f32>
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4,1],f32>
|
||||||
func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
|
func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
|
||||||
%int-1 = torch.constant.int -1
|
%int-1 = torch.constant.int -1
|
||||||
|
@ -53,9 +53,9 @@ func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>)
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_middle(
|
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_middle(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
|
||||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
|
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
|
||||||
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32>
|
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32>
|
||||||
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x1x4xf32> -> !torch.vtensor<[2,3,1,4],f32>
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x1x4xf32> -> !torch.vtensor<[2,3,1,4],f32>
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,1,4],f32>
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,1,4],f32>
|
||||||
func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
|
func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
|
||||||
%int2 = torch.constant.int 2
|
%int2 = torch.constant.int 2
|
||||||
|
|
|
@ -4,15 +4,15 @@
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int {
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int {
|
||||||
// CHECK: %[[VAL_1:.*]] = torch.constant.int 2
|
// CHECK: %[[VAL_1:.*]] = torch.constant.int 2
|
||||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.to_i1 %[[VAL_0]]
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_i1 %[[VAL_0]]
|
||||||
// CHECK: %[[VAL_4:.*]] = scf.if %[[VAL_3]] -> (i64) {
|
// CHECK: %[[VAL_4:.*]] = scf.if %[[VAL_3]] -> (i64) {
|
||||||
// CHECK: %[[VAL_5:.*]] = torch.to_i64 %[[VAL_1]]
|
// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_1]]
|
||||||
// CHECK: scf.yield %[[VAL_5]] : i64
|
// CHECK: scf.yield %[[VAL_5]] : i64
|
||||||
// CHECK: } else {
|
// CHECK: } else {
|
||||||
// CHECK: %[[VAL_6:.*]] = torch.to_i64 %[[VAL_2]]
|
// CHECK: %[[VAL_6:.*]] = torch_c.to_i64 %[[VAL_2]]
|
||||||
// CHECK: scf.yield %[[VAL_6]] : i64
|
// CHECK: scf.yield %[[VAL_6]] : i64
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: %[[VAL_7:.*]] = torch.from_i64 %[[VAL_8:.*]]
|
// CHECK: %[[VAL_7:.*]] = torch_c.from_i64 %[[VAL_8:.*]]
|
||||||
// CHECK: return %[[VAL_7]] : !torch.int
|
// CHECK: return %[[VAL_7]] : !torch.int
|
||||||
func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
|
func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
|
||||||
%int2 = torch.constant.int 2
|
%int2 = torch.constant.int 2
|
||||||
|
@ -31,22 +31,22 @@ func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
|
||||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
|
||||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 4
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int 4
|
||||||
// CHECK: %[[VAL_5:.*]] = torch.to_i1 %[[VAL_0]]
|
// CHECK: %[[VAL_5:.*]] = torch_c.to_i1 %[[VAL_0]]
|
||||||
// CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_5]] -> (i64) {
|
// CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_5]] -> (i64) {
|
||||||
// CHECK: %[[VAL_7:.*]] = torch.to_i1 %[[VAL_1]]
|
// CHECK: %[[VAL_7:.*]] = torch_c.to_i1 %[[VAL_1]]
|
||||||
// CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (i64) {
|
// CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (i64) {
|
||||||
// CHECK: %[[VAL_9:.*]] = torch.to_i64 %[[VAL_2]]
|
// CHECK: %[[VAL_9:.*]] = torch_c.to_i64 %[[VAL_2]]
|
||||||
// CHECK: scf.yield %[[VAL_9]] : i64
|
// CHECK: scf.yield %[[VAL_9]] : i64
|
||||||
// CHECK: } else {
|
// CHECK: } else {
|
||||||
// CHECK: %[[VAL_10:.*]] = torch.to_i64 %[[VAL_3]]
|
// CHECK: %[[VAL_10:.*]] = torch_c.to_i64 %[[VAL_3]]
|
||||||
// CHECK: scf.yield %[[VAL_10]] : i64
|
// CHECK: scf.yield %[[VAL_10]] : i64
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: scf.yield %[[VAL_11:.*]] : i64
|
// CHECK: scf.yield %[[VAL_11:.*]] : i64
|
||||||
// CHECK: } else {
|
// CHECK: } else {
|
||||||
// CHECK: %[[VAL_12:.*]] = torch.to_i64 %[[VAL_4]]
|
// CHECK: %[[VAL_12:.*]] = torch_c.to_i64 %[[VAL_4]]
|
||||||
// CHECK: scf.yield %[[VAL_12]] : i64
|
// CHECK: scf.yield %[[VAL_12]] : i64
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: %[[VAL_13:.*]] = torch.from_i64 %[[VAL_14:.*]]
|
// CHECK: %[[VAL_13:.*]] = torch_c.from_i64 %[[VAL_14:.*]]
|
||||||
// CHECK: return %[[VAL_13]] : !torch.int
|
// CHECK: return %[[VAL_13]] : !torch.int
|
||||||
func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int {
|
func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int {
|
||||||
%int2 = torch.constant.int 2
|
%int2 = torch.constant.int 2
|
||||||
|
|
|
@ -3,10 +3,10 @@
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.dim(
|
// CHECK-LABEL: func @torch.aten.dim(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.int {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.int {
|
||||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<*,f32> -> tensor<*xf32>
|
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<*,f32> -> tensor<*xf32>
|
||||||
// CHECK: %[[RANK:.*]] = rank %[[BUILTIN_TENSOR]] : tensor<*xf32>
|
// CHECK: %[[RANK:.*]] = rank %[[BUILTIN_TENSOR]] : tensor<*xf32>
|
||||||
// CHECK: %[[RANK_I64:.*]] = index_cast %[[RANK]] : index to i64
|
// CHECK: %[[RANK_I64:.*]] = index_cast %[[RANK]] : index to i64
|
||||||
// CHECK: %[[RANK_TORCH_INT:.*]] = torch.from_i64 %[[RANK_I64]]
|
// CHECK: %[[RANK_TORCH_INT:.*]] = torch_c.from_i64 %[[RANK_I64]]
|
||||||
// CHECK: return %[[RANK_TORCH_INT]] : !torch.int
|
// CHECK: return %[[RANK_TORCH_INT]] : !torch.int
|
||||||
func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int {
|
func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int {
|
||||||
%0 = torch.aten.dim %arg0 : !torch.vtensor<*,f32> -> !torch.int
|
%0 = torch.aten.dim %arg0 : !torch.vtensor<*,f32> -> !torch.int
|
||||||
|
@ -16,10 +16,10 @@ func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int {
|
||||||
// CHECK-LABEL: func @torch.aten.ne.int(
|
// CHECK-LABEL: func @torch.aten.ne.int(
|
||||||
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
|
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
|
||||||
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
|
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
|
||||||
// CHECK: %[[LHS_I64:.*]] = torch.to_i64 %[[LHS]]
|
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
|
||||||
// CHECK: %[[RHS_I64:.*]] = torch.to_i64 %[[RHS]]
|
// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
|
||||||
// CHECK: %[[CMP:.*]] = cmpi ne, %[[LHS_I64]], %[[RHS_I64]] : i64
|
// CHECK: %[[CMP:.*]] = cmpi ne, %[[LHS_I64]], %[[RHS_I64]] : i64
|
||||||
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch.from_i1 %[[CMP]]
|
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
|
||||||
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
|
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
|
||||||
func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
|
func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
|
||||||
%0 = torch.aten.ne.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
|
%0 = torch.aten.ne.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
@ -29,10 +29,10 @@ func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
|
||||||
// CHECK-LABEL: func @torch.aten.gt.int(
|
// CHECK-LABEL: func @torch.aten.gt.int(
|
||||||
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
|
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
|
||||||
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
|
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
|
||||||
// CHECK: %[[LHS_I64:.*]] = torch.to_i64 %[[LHS]]
|
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
|
||||||
// CHECK: %[[RHS_I64:.*]] = torch.to_i64 %[[RHS]]
|
// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
|
||||||
// CHECK: %[[CMP:.*]] = cmpi sgt, %[[LHS_I64]], %[[RHS_I64]] : i64
|
// CHECK: %[[CMP:.*]] = cmpi sgt, %[[LHS_I64]], %[[RHS_I64]] : i64
|
||||||
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch.from_i1 %[[CMP]]
|
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
|
||||||
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
|
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
|
||||||
func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
|
func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
|
||||||
%0 = torch.aten.gt.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
|
%0 = torch.aten.gt.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
@ -41,7 +41,7 @@ func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
|
// CHECK-LABEL: func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
|
||||||
// CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[VTENSOR:.*]] = torch.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
|
// CHECK: %[[VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||||
// CHECK: return %[[VTENSOR]] : !torch.vtensor<[],f32>
|
// CHECK: return %[[VTENSOR]] : !torch.vtensor<[],f32>
|
||||||
func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
|
func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
|
||||||
%0 = torch.vtensor.literal(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32>
|
%0 = torch.vtensor.literal(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32>
|
||||||
|
@ -50,7 +50,7 @@ func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.constant.bool() -> !torch.bool {
|
// CHECK-LABEL: func @torch.constant.bool() -> !torch.bool {
|
||||||
// CHECK: %[[CST:.*]] = constant true
|
// CHECK: %[[CST:.*]] = constant true
|
||||||
// CHECK: %[[BOOL:.*]] = torch.from_i1 %[[CST]]
|
// CHECK: %[[BOOL:.*]] = torch_c.from_i1 %[[CST]]
|
||||||
// CHECK: return %[[BOOL]] : !torch.bool
|
// CHECK: return %[[BOOL]] : !torch.bool
|
||||||
func @torch.constant.bool() -> !torch.bool {
|
func @torch.constant.bool() -> !torch.bool {
|
||||||
%true = torch.constant.bool true
|
%true = torch.constant.bool true
|
||||||
|
@ -59,7 +59,7 @@ func @torch.constant.bool() -> !torch.bool {
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.constant.float() -> !torch.float {
|
// CHECK-LABEL: func @torch.constant.float() -> !torch.float {
|
||||||
// CHECK: %[[CST:.*]] = constant 1.000000e+00 : f64
|
// CHECK: %[[CST:.*]] = constant 1.000000e+00 : f64
|
||||||
// CHECK: %[[FLOAT:.*]] = torch.from_f64 %[[CST]]
|
// CHECK: %[[FLOAT:.*]] = torch_c.from_f64 %[[CST]]
|
||||||
// CHECK: return %[[FLOAT]] : !torch.float
|
// CHECK: return %[[FLOAT]] : !torch.float
|
||||||
func @torch.constant.float() -> !torch.float {
|
func @torch.constant.float() -> !torch.float {
|
||||||
%float = torch.constant.float 1.000000e+00
|
%float = torch.constant.float 1.000000e+00
|
||||||
|
@ -68,7 +68,7 @@ func @torch.constant.float() -> !torch.float {
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.constant.int() -> !torch.int {
|
// CHECK-LABEL: func @torch.constant.int() -> !torch.int {
|
||||||
// CHECK: %[[CST:.*]] = constant 1 : i64
|
// CHECK: %[[CST:.*]] = constant 1 : i64
|
||||||
// CHECK: %[[INT:.*]] = torch.from_i64 %[[CST]]
|
// CHECK: %[[INT:.*]] = torch_c.from_i64 %[[CST]]
|
||||||
// CHECK: return %[[INT]] : !torch.int
|
// CHECK: return %[[INT]] : !torch.int
|
||||||
func @torch.constant.int() -> !torch.int {
|
func @torch.constant.int() -> !torch.int {
|
||||||
%int1 = torch.constant.int 1
|
%int1 = torch.constant.int 1
|
||||||
|
|
|
@ -13,19 +13,6 @@ func @torch.linear_params.create(%arg0: !torch.tensor, %arg1: !torch.tensor) ->
|
||||||
return %with_bias, %without_bias : !torch.LinearParams, !torch.LinearParams
|
return %with_bias, %without_bias : !torch.LinearParams, !torch.LinearParams
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @builtin_tensor_interop(
|
|
||||||
func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xsi8>, %arg2: !torch.vtensor<*,f32>, %arg3: !torch.vtensor<[3,?],si8>) {
|
|
||||||
// CHECK: torch.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
|
|
||||||
%0 = torch.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
|
|
||||||
// CHECK: torch.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8>
|
|
||||||
%1 = torch.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8>
|
|
||||||
// CHECK: torch.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32>
|
|
||||||
%2 = torch.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32>
|
|
||||||
// CHECK: torch.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xsi8>
|
|
||||||
%3 = torch.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xsi8>
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK: @tensor.default() -> !torch.tensor
|
// CHECK: @tensor.default() -> !torch.tensor
|
||||||
func private @tensor.default() -> !torch.tensor
|
func private @tensor.default() -> !torch.tensor
|
||||||
// CHECK: @tensor.default_explicit() -> !torch.tensor{{$}}
|
// CHECK: @tensor.default_explicit() -> !torch.tensor{{$}}
|
||||||
|
|
|
@ -7,8 +7,8 @@
|
||||||
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
|
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
|
||||||
// CHECK: return %[[ARG]] : tensor<f32>
|
// CHECK: return %[[ARG]] : tensor<f32>
|
||||||
func @eliminate_materializations(%arg0: tensor<f32>) -> tensor<f32> {
|
func @eliminate_materializations(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
%0 = torch.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32>
|
%0 = torch_c.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32>
|
||||||
%1 = torch.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32>
|
%1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32>
|
||||||
return %1 : tensor<f32>
|
return %1 : tensor<f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,8 +19,8 @@ func @eliminate_materializations(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
// CHECK-SAME: %[[ARG:.*]]: i1) -> i1 {
|
// CHECK-SAME: %[[ARG:.*]]: i1) -> i1 {
|
||||||
// CHECK: return %[[ARG]] : i1
|
// CHECK: return %[[ARG]] : i1
|
||||||
func @eliminate_materializations$torch.bool(%arg0: i1) -> i1 {
|
func @eliminate_materializations$torch.bool(%arg0: i1) -> i1 {
|
||||||
%0 = torch.from_i1 %arg0
|
%0 = torch_c.from_i1 %arg0
|
||||||
%1 = torch.to_i1 %0
|
%1 = torch_c.to_i1 %0
|
||||||
return %1 : i1
|
return %1 : i1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,8 +28,8 @@ func @eliminate_materializations$torch.bool(%arg0: i1) -> i1 {
|
||||||
// CHECK-SAME: %[[ARG:.*]]: i64) -> i64 {
|
// CHECK-SAME: %[[ARG:.*]]: i64) -> i64 {
|
||||||
// CHECK: return %[[ARG]] : i64
|
// CHECK: return %[[ARG]] : i64
|
||||||
func @eliminate_materializations$torch.int(%arg0: i64) -> i64 {
|
func @eliminate_materializations$torch.int(%arg0: i64) -> i64 {
|
||||||
%0 = torch.from_i64 %arg0
|
%0 = torch_c.from_i64 %arg0
|
||||||
%1 = torch.to_i64 %0
|
%1 = torch_c.to_i64 %0
|
||||||
return %1 : i64
|
return %1 : i64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,24 +37,33 @@ func @eliminate_materializations$torch.int(%arg0: i64) -> i64 {
|
||||||
// CHECK-SAME: %[[ARG:.*]]: f64) -> f64 {
|
// CHECK-SAME: %[[ARG:.*]]: f64) -> f64 {
|
||||||
// CHECK: return %[[ARG]] : f64
|
// CHECK: return %[[ARG]] : f64
|
||||||
func @eliminate_materializations$torch.float(%arg0: f64) -> f64 {
|
func @eliminate_materializations$torch.float(%arg0: f64) -> f64 {
|
||||||
%0 = torch.from_f64 %arg0
|
%0 = torch_c.from_f64 %arg0
|
||||||
%1 = torch.to_f64 %0
|
%1 = torch_c.to_f64 %0
|
||||||
return %1 : f64
|
return %1 : f64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @eliminate_materializations$torch.list(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: !iree.list<f64>) -> !iree.list<f64> {
|
||||||
|
// CHECK: return %[[ARG]] : !iree.list<f64>
|
||||||
|
func @eliminate_materializations$torch.list(%arg0: !iree.list<f64>) -> !iree.list<f64> {
|
||||||
|
%0 = torch_c.from_iree_list %arg0 : !iree.list<f64> -> !torch.list<!torch.float>
|
||||||
|
%1 = torch_c.to_iree_list %0 : !torch.list<!torch.float> -> !iree.list<f64>
|
||||||
|
return %1 : !iree.list<f64>
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @unable_to_convert_lone_buffer_cast() -> tensor<f32> {
|
func @unable_to_convert_lone_buffer_cast() -> tensor<f32> {
|
||||||
// expected-error @+1 {{failed to legalize operation 'test.source'}}
|
// expected-error @+1 {{failed to legalize operation 'test.source'}}
|
||||||
%0 = "test.source"() : () -> !torch.vtensor<[],f32>
|
%0 = "test.source"() : () -> !torch.vtensor<[],f32>
|
||||||
%1 = torch.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32>
|
%1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32>
|
||||||
return %1 : tensor<f32>
|
return %1 : tensor<f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @unable_to_convert_lone_tensor_load(%arg0: tensor<f32>) {
|
func @unable_to_convert_lone_tensor_load(%arg0: tensor<f32>) {
|
||||||
%0 = torch.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32>
|
%0 = torch_c.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32>
|
||||||
// expected-error @+1 {{failed to legalize operation 'test.sink'}}
|
// expected-error @+1 {{failed to legalize operation 'test.sink'}}
|
||||||
"test.sink"(%0) : (!torch.vtensor<[],f32>) -> ()
|
"test.sink"(%0) : (!torch.vtensor<[],f32>) -> ()
|
||||||
return
|
return
|
|
@ -5,8 +5,8 @@
|
||||||
|
|
||||||
// CHECK-LABEL: func @identity(
|
// CHECK-LABEL: func @identity(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
|
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
|
||||||
// CHECK: %[[TENSOR:.*]] = torch.from_builtin_tensor %[[ARG]] : tensor<f32> -> !torch.vtensor<[],f32>
|
// CHECK: %[[TENSOR:.*]] = torch_c.from_builtin_tensor %[[ARG]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||||
// CHECK: %[[MEMREF:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
|
// CHECK: %[[MEMREF:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||||
// CHECK: return %[[MEMREF]] : tensor<f32>
|
// CHECK: return %[[MEMREF]] : tensor<f32>
|
||||||
func @identity(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
func @identity(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||||
return %arg0 : !torch.vtensor<[],f32>
|
return %arg0 : !torch.vtensor<[],f32>
|
||||||
|
@ -14,12 +14,12 @@ func @identity(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||||
|
|
||||||
// CHECK-LABEL: func @block_arguments(
|
// CHECK-LABEL: func @block_arguments(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
|
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
|
||||||
// CHECK: %[[T1:.*]] = torch.from_builtin_tensor %[[ARG]] : tensor<f32> -> !torch.vtensor<[],f32>
|
// CHECK: %[[T1:.*]] = torch_c.from_builtin_tensor %[[ARG]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||||
// CHECK: %[[M1:.*]] = torch.to_builtin_tensor %[[T1]] : !torch.vtensor<[],f32> -> tensor<f32>
|
// CHECK: %[[M1:.*]] = torch_c.to_builtin_tensor %[[T1]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||||
// CHECK: br ^bb1(%[[M1]] : tensor<f32>)
|
// CHECK: br ^bb1(%[[M1]] : tensor<f32>)
|
||||||
// CHECK: ^bb1(%[[BBARG:.*]]: tensor<f32>):
|
// CHECK: ^bb1(%[[BBARG:.*]]: tensor<f32>):
|
||||||
// CHECK: %[[T2:.*]] = torch.from_builtin_tensor %[[BBARG]] : tensor<f32> -> !torch.vtensor<[],f32>
|
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[BBARG]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||||
// CHECK: %[[M2:.*]] = torch.to_builtin_tensor %[[T2]] : !torch.vtensor<[],f32> -> tensor<f32>
|
// CHECK: %[[M2:.*]] = torch_c.to_builtin_tensor %[[T2]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||||
// CHECK: return %[[M2]] : tensor<f32>
|
// CHECK: return %[[M2]] : tensor<f32>
|
||||||
func @block_arguments(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
func @block_arguments(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||||
br ^bb1(%arg0: !torch.vtensor<[],f32>)
|
br ^bb1(%arg0: !torch.vtensor<[],f32>)
|
||||||
|
@ -38,8 +38,8 @@ func @call_source() -> !torch.vtensor<[],f32> {
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @call_sink(
|
// CHECK-LABEL: func @call_sink(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) {
|
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) {
|
||||||
// CHECK: %[[TENSOR:.*]] = torch.from_builtin_tensor %[[ARG]] : tensor<f32> -> !torch.vtensor<[],f32>
|
// CHECK: %[[TENSOR:.*]] = torch_c.from_builtin_tensor %[[ARG]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||||
// CHECK: %[[MEMREF:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
|
// CHECK: %[[MEMREF:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||||
// CHECK: call @sink(%[[MEMREF]]) : (tensor<f32>) -> ()
|
// CHECK: call @sink(%[[MEMREF]]) : (tensor<f32>) -> ()
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
func private @sink(!torch.vtensor<[],f32>)
|
func private @sink(!torch.vtensor<[],f32>)
|
||||||
|
@ -50,7 +50,7 @@ func @call_sink(%arg0: !torch.vtensor<[],f32>) {
|
||||||
|
|
||||||
// CHECK-LABEL: func @unconverted_op_in_body() -> tensor<f32> {
|
// CHECK-LABEL: func @unconverted_op_in_body() -> tensor<f32> {
|
||||||
// CHECK: %[[TENSOR:.*]] = "test.source"() : () -> !torch.vtensor<[],f32>
|
// CHECK: %[[TENSOR:.*]] = "test.source"() : () -> !torch.vtensor<[],f32>
|
||||||
// CHECK: %[[MEMREF:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
|
// CHECK: %[[MEMREF:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||||
// CHECK: return %[[MEMREF]] : tensor<f32>
|
// CHECK: return %[[MEMREF]] : tensor<f32>
|
||||||
func @unconverted_op_in_body() -> !torch.vtensor<[],f32> {
|
func @unconverted_op_in_body() -> !torch.vtensor<[],f32> {
|
||||||
%0 = "test.source"() : () -> !torch.vtensor<[],f32>
|
%0 = "test.source"() : () -> !torch.vtensor<[],f32>
|
||||||
|
@ -98,8 +98,8 @@ func @bwhile(%arg0: i64, %arg1: i64) -> i64 {
|
||||||
|
|
||||||
// CHECK-LABEL: func @identity$torch.bool(
|
// CHECK-LABEL: func @identity$torch.bool(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: i1) -> i1 {
|
// CHECK-SAME: %[[ARG:.*]]: i1) -> i1 {
|
||||||
// CHECK: %[[TORCH_BOOL:.*]] = torch.from_i1 %[[ARG]]
|
// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[ARG]]
|
||||||
// CHECK: %[[I1:.*]] = torch.to_i1 %[[TORCH_BOOL]]
|
// CHECK: %[[I1:.*]] = torch_c.to_i1 %[[TORCH_BOOL]]
|
||||||
// CHECK: return %[[I1]] : i1
|
// CHECK: return %[[I1]] : i1
|
||||||
func @identity$torch.bool(%arg0: !torch.bool) -> !torch.bool {
|
func @identity$torch.bool(%arg0: !torch.bool) -> !torch.bool {
|
||||||
return %arg0 : !torch.bool
|
return %arg0 : !torch.bool
|
||||||
|
@ -107,8 +107,8 @@ func @identity$torch.bool(%arg0: !torch.bool) -> !torch.bool {
|
||||||
|
|
||||||
// CHECK-LABEL: func @identity$torch.int(
|
// CHECK-LABEL: func @identity$torch.int(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: i64) -> i64 {
|
// CHECK-SAME: %[[ARG:.*]]: i64) -> i64 {
|
||||||
// CHECK: %[[TORCH_INT:.*]] = torch.from_i64 %[[ARG]]
|
// CHECK: %[[TORCH_INT:.*]] = torch_c.from_i64 %[[ARG]]
|
||||||
// CHECK: %[[I64:.*]] = torch.to_i64 %[[TORCH_INT]]
|
// CHECK: %[[I64:.*]] = torch_c.to_i64 %[[TORCH_INT]]
|
||||||
// CHECK: return %[[I64]] : i64
|
// CHECK: return %[[I64]] : i64
|
||||||
func @identity$torch.int(%arg0: !torch.int) -> !torch.int {
|
func @identity$torch.int(%arg0: !torch.int) -> !torch.int {
|
||||||
return %arg0 : !torch.int
|
return %arg0 : !torch.int
|
||||||
|
@ -116,8 +116,8 @@ func @identity$torch.int(%arg0: !torch.int) -> !torch.int {
|
||||||
|
|
||||||
// CHECK-LABEL: func @identity$torch.float(
|
// CHECK-LABEL: func @identity$torch.float(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: f64) -> f64 {
|
// CHECK-SAME: %[[ARG:.*]]: f64) -> f64 {
|
||||||
// CHECK: %[[TORCH_FLOAT:.*]] = torch.from_f64 %[[ARG]]
|
// CHECK: %[[TORCH_FLOAT:.*]] = torch_c.from_f64 %[[ARG]]
|
||||||
// CHECK: %[[F64:.*]] = torch.to_f64 %[[TORCH_FLOAT]]
|
// CHECK: %[[F64:.*]] = torch_c.to_f64 %[[TORCH_FLOAT]]
|
||||||
// CHECK: return %[[F64]] : f64
|
// CHECK: return %[[F64]] : f64
|
||||||
func @identity$torch.float(%arg0: !torch.float) -> !torch.float {
|
func @identity$torch.float(%arg0: !torch.float) -> !torch.float {
|
||||||
return %arg0 : !torch.float
|
return %arg0 : !torch.float
|
|
@ -0,0 +1,14 @@
|
||||||
|
// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @builtin_tensor_interop(
|
||||||
|
func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xsi8>, %arg2: !torch.vtensor<*,f32>, %arg3: !torch.vtensor<[3,?],si8>) {
|
||||||
|
// CHECK: torch_c.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
|
||||||
|
%0 = torch_c.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
|
||||||
|
// CHECK: torch_c.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8>
|
||||||
|
%1 = torch_c.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8>
|
||||||
|
// CHECK: torch_c.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32>
|
||||||
|
%2 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32>
|
||||||
|
// CHECK: torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xsi8>
|
||||||
|
%3 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xsi8>
|
||||||
|
return
|
||||||
|
}
|
Loading…
Reference in New Issue