Add TorchToIREE and factor out TorchConversion dialect.

This converts a basic list op (torch.prim.ListConstruct) to the IREE
dialect.

```
    def forward(self, x: float):
            return [x, x]
```

turns into:

```
builtin.func @forward(%arg0: !torch.float) -> !torch.list<!torch.float> {
  %0 = torch.prim.ListConstruct %arg0, %arg0 : (!torch.float, !torch.float) -> !torch.list<!torch.float>
  return %0 : !torch.list<!torch.float>
}
```

which turns into:

```
builtin.func @forward(%arg0: f64) -> !iree.list<f64> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %c2 = constant 2 : index
  %0 = iree.list.create %c2 : !iree.list<f64>
  iree.list.set %0[%c0], %arg0 : !iree.list<f64>, f64
  iree.list.set %0[%c1], %arg0 : !iree.list<f64>, f64
  return %0 : !iree.list<f64>
}
```

As part of doing this, I realized that it was time to formalize the IR
form that we reach right before running TorchTo{Linalg,Std,...}. We now
call it the "Torch backend contract". We then lower the "Torch backend
contract" to the "npcomp backend contract", which involves the new
TorchConversion (`torch_c`) dialect, which holds ops that need to
operate on both the npcomp backend types (e.g. builtin tensors, i1, IREE
list, etc.) and the `!torch` types.

This made more sense, as I realized that if I didn't factor out
`torch_c` then the Torch dialect would have a dependency on IREE
dialect (we previously didn't notice this was an issue because we only
depended on `builtin` types), which seemed wrong to me.

Recommended review order:
- TorchToIREE.cpp / `TorchToIREE/basic.mlir`
- Look at the new structure of createTorchScriptToNpcompBackendPipeline.
  It now lives in TorchConversion/Transforms/Passes.cpp and cleanly
  calls into `Torch::createTorchScriptToTorchBackendPipeline` for the
  frontend lowering to the Torch backend contract.
- Mechanical change extracting
  `torch_c.{to,from}_{i1,i64,f64,builtin_tensor,iree_list}` into a new
  TorchConversion dialect, and a few passes specific to the lowering
  from the Torch backend contract to the npcomp backend contract.
- Minor fixes to TorchToLinalg.cpp to use unconverted operands (now that
  we convert lists as part of operand materialization, we need to use
  the original operands). Also added test for AtenMaxPool2dOp and fixed
  m_TorchConstantIntList.
- TmpDeleteDeadIREELists pass. Temporary pass for deleting dead IREE lists that
  are created as part of operand materialization for conv/max pool/avg pool ops
  in TorchToLinalg.
pull/283/head
Sean Silva 2021-08-11 14:40:08 -07:00
parent 85ff8b692b
commit cab8d922ec
59 changed files with 1195 additions and 436 deletions

View File

@ -128,6 +128,10 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
"suffix = '${PYTHON_MODULE_SUFFIX}', " "suffix = '${PYTHON_MODULE_SUFFIX}', "
"extension = '${PYTHON_MODULE_EXTENSION}") "extension = '${PYTHON_MODULE_EXTENSION}")
# Include the iree-dialects external project.
set(LLVM_EXTERNAL_PROJECTS "iree-dialects")
set(LLVM_EXTERNAL_IREE_DIALECTS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects")
# LLVM configuration. # LLVM configuration.
message(STATUS "*** ADDING LLVM ***") message(STATUS "*** ADDING LLVM ***")
add_subdirectory( add_subdirectory(
@ -177,6 +181,8 @@ include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${MLIR_INCLUDE_DIRS}) include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/llvm/tools/iree-dialects/include)
link_directories(${LLVM_BUILD_LIBRARY_DIR}) link_directories(${LLVM_BUILD_LIBRARY_DIR})
add_definitions(${LLVM_DEFINITIONS}) add_definitions(${LLVM_DEFINITIONS})
set(NPCOMP_TABLEGEN_ARGS "") set(NPCOMP_TABLEGEN_ARGS "")

View File

@ -22,7 +22,7 @@ using namespace mlir::iree;
void IREEDialect::initialize() { void IREEDialect::initialize() {
addTypes< addTypes<
#define GET_TYPEDEF_LIST #define GET_TYPEDEF_LIST
#include "iree-dialects/Dialect/IREE/IREEOps.cpp.inc" #include "iree-dialects/Dialect/IREE/IREEOpsTypes.cpp.inc"
>(); >();
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST

View File

@ -104,6 +104,14 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
let constructor = "mlir::NPCOMP::createConvertTorchToLinalgPass()"; let constructor = "mlir::NPCOMP::createConvertTorchToLinalgPass()";
} }
def ConvertTorchToIREE : Pass<"convert-torch-to-iree", "FuncOp"> {
let summary = "Convert recognized Torch ops to IREE ops";
let description = [{
TODO
}];
let constructor = "mlir::NPCOMP::createConvertTorchToIREEPass()";
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Basicpy conversions // Basicpy conversions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -0,0 +1,21 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_CONVERSION_TORCHTOIREE_TORCHTOIREE_H
#define NPCOMP_CONVERSION_TORCHTOIREE_TORCHTOIREE_H
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace NPCOMP {
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToIREEPass();
}
} // namespace mlir
#endif // NPCOMP_CONVERSION_TORCHTOIREE_TORCHTOIREE_H

View File

@ -3,3 +3,4 @@ add_subdirectory(Numpy)
add_subdirectory(Refback) add_subdirectory(Refback)
add_subdirectory(Refbackrt) add_subdirectory(Refbackrt)
add_subdirectory(Torch) add_subdirectory(Torch)
add_subdirectory(TorchConversion)

View File

@ -19,7 +19,18 @@ def Torch_Dialect : Dialect {
This dialect maintains a fairly isomorphic representation with TorchScript. This dialect maintains a fairly isomorphic representation with TorchScript.
TODO: Add more detail here. This dialect also provides transforms that lower it to the
"Torch backend contract", which is an IR form that we present to
later conversions, such as conversion to the npcomp backend contract.
The Torch backend contract significantly simplifies the IR representation
and puts it in a form easier for later lowering to work on. Specifically:
- The TorchScript object graph has been flattened to a list of globals (see
the GlobalizeObjectGraph tranformation).
- Most of the operations have been changed to operate on value-semantic
tensors (see MaximizeValueSemantics)
- The number of op variants have been reduced (see ReduceOpVariants)
- Tensor sizes have been analyzed and static ranks inferred where possible
and propagated throughout the program.
}]; }];
let hasRegionArgAttrVerify = 1; let hasRegionArgAttrVerify = 1;

View File

@ -87,15 +87,16 @@ struct torch_list_construct_op_binder {
: bind_values(bvs) {} : bind_values(bvs) {}
bool match(Operation *op) { bool match(Operation *op) {
if (auto constantNums = dyn_cast<Torch::PrimListConstructOp>(op)) { auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
for (Value value : constantNums.elements()) { if (!listConstruct)
return false;
for (Value value : listConstruct.elements()) {
int64_t num; int64_t num;
if (matchPattern(value, m_TorchConstantInt(&num))) if (matchPattern(value, m_TorchConstantInt(&num)))
bind_values.push_back(num); bind_values.push_back(num);
else else
return false; return false;
} }
}
return true; return true;
} }
}; };

View File

@ -611,156 +611,6 @@ def Torch_ConstantBoolOp : Torch_Op<"constant.bool",
let hasFolder = 1; let hasFolder = 1;
} }
//===----------------------------------------------------------------------===//
// Conversions to builtin types.
//===----------------------------------------------------------------------===//
def Torch_ToBuiltinTensorOp : Torch_Op<"to_builtin_tensor", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert a `!torch.vtensor` to a `tensor`";
let description = [{
This op only operates on ValueTensorType, to avoid conflating conversions
between value-semantic and non-value-semantic types.
}];
let arguments = (ins
Torch_ValueTensorType:$operand
);
let results = (outs
AnyTensor:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
def Torch_FromBuiltinTensorOp : Torch_Op<"from_builtin_tensor", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert a `tensor` to a `!torch.vtensor`";
let description = [{
This op only operates on ValueTensorType, to avoid conflating conversions
between value-semantic and non-value-semantic types.
}];
let arguments = (ins
AnyTensor:$operand
);
let results = (outs
Torch_ValueTensorType:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
def Torch_ToI1Op : Torch_Op<"to_i1", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert a `!torch.bool` to an `i1`";
let description = [{
This op is primarily useful as a materialization during dialect conversion.
}];
let arguments = (ins
Torch_BoolType:$operand
);
let results = (outs
I1:$result
);
let assemblyFormat = [{
$operand attr-dict
}];
}
def Torch_FromI1Op : Torch_Op<"from_i1", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert an `i1` to a `!torch.bool`";
let description = [{
This op is primarily useful as a materialization during dialect conversion.
}];
let arguments = (ins
I1:$operand
);
let results = (outs
Torch_BoolType:$result
);
let assemblyFormat = [{
$operand attr-dict
}];
}
def Torch_ToI64Op : Torch_Op<"to_i64", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert a `!torch.int` to an `i64`";
let description = [{
This op is primarily useful as a materialization during dialect conversion.
}];
let arguments = (ins
Torch_IntType:$operand
);
let results = (outs
I64:$result
);
let assemblyFormat = [{
$operand attr-dict
}];
}
def Torch_FromI64Op : Torch_Op<"from_i64", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert an `i64` to a `!torch.int`";
let description = [{
This op is primarily useful as a materialization during dialect conversion.
}];
let arguments = (ins
I64:$operand
);
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = [{
$operand attr-dict
}];
}
def Torch_ToF64Op : Torch_Op<"to_f64", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert a `!torch.float` to an `f64`";
let description = [{
This op is primarily useful as a materialization during dialect conversion.
}];
let arguments = (ins
Torch_FloatType:$operand
);
let results = (outs
F64:$result
);
let assemblyFormat = [{
$operand attr-dict
}];
}
def Torch_FromF64Op : Torch_Op<"from_f64", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert an `f64` to a `!torch.float`";
let description = [{
This op is primarily useful as a materialization during dialect conversion.
}];
let arguments = (ins
F64:$operand
);
let results = (outs
Torch_FloatType:$result
);
let assemblyFormat = [{
$operand attr-dict
}];
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Additional ops used to model TorchScript's Graph's / Node's. // Additional ops used to model TorchScript's Graph's / Node's.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -31,16 +31,14 @@ struct TorchLoweringPipelineOptions
}; };
/// Creates a pipeline that lowers the object graph IR that is produced by /// Creates a pipeline that lowers the object graph IR that is produced by
/// TorchScript import into the form expected by npcomp-verify-backend-contract. /// TorchScript import into the form expected by torch-verify-backend-contract.
void createLowerObjectGraphPipeline( void createTorchScriptToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options); OpPassManager &pm, const TorchLoweringPipelineOptions &options);
/// Creates a pipeline that lowers a flat list of funcs and global slots /// Creates a pipeline that lowers a flat list of funcs and global slots
/// with the torch and aten dialects and mutable arrays and converts it to /// with the torch and aten dialects and mutable arrays and converts it to
/// the form required by npcomp-verify-backend-contract, in particular /// the form required by torch-verify-backend-contract.
/// lowering most arrays to ranked tensors of known dtype, lowering aten ops to void createGlobalizedModuleToTorchBackendPipeline(
/// linalg, converting torch.prim.* ops to elementary math operations.
void createLowerToNpcompBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options); OpPassManager &pm, const TorchLoweringPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass(); std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
@ -55,14 +53,6 @@ std::unique_ptr<OperationPass<FuncOp>> createMaximizeValueSemanticsPass();
std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass(); std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyInvariantsBeforeBackendLoweringPass();
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
std::unique_ptr<OperationPass<FuncOp>>
createFinalizingBackendTypeConversionPass();
} // namespace Torch } // namespace Torch
/// Registers all Torch transformation passes. /// Registers all Torch transformation passes.

View File

@ -214,40 +214,4 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
}]; }];
} }
def VerifyInvariantsBeforeBackendLowering
: Pass<"torch-verify-invariants-before-backend-lowering", "ModuleOp"> {
let summary = "Verify invariants required by backend lowering";
let constructor =
"mlir::NPCOMP::Torch::createVerifyInvariantsBeforeBackendLoweringPass()";
let description = [{
This pass checks any invariants needed by the process of lowering the
`torch` dialect to the npcomp backend contract.
The most important invariant is that all tensors should be ranked and have
a known dtype. It is useful to catch this early because it usually
represents a simple bug in RefineTypes, but can manifest as many different
kinds of obscure symptoms during lowering.
}];
}
def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "ModuleOp"> {
let summary = "Convert functions to operate on builtin tensors";
let constructor = "mlir::NPCOMP::Torch::createFuncBackendTypeConversionPass()";
let description = [{
Partial type conversion pass analogous in scope to the upstream
`func-bufferize` pass. See details there.
}];
}
def FinalizingBackendTypeConversion
: Pass<"torch-finalizing-backend-type-conversion", "FuncOp"> {
let summary = "Finalizes a partial conversion to builtin tensors";
let constructor =
"mlir::NPCOMP::Torch::createFinalizingBackendTypeConversionPass()";
let description = [{
Analogous in scope to the upstream `finalizing-bufferize` pass.
See details there.
}];
}
#endif // NPCOMP_TORCH_PASSES #endif // NPCOMP_TORCH_PASSES

View File

@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@ -0,0 +1,15 @@
set(LLVM_TARGET_DEFINITIONS TorchConversionOps.td)
mlir_tablegen(TorchConversionOps.h.inc -gen-op-decls)
mlir_tablegen(TorchConversionOps.cpp.inc -gen-op-defs)
mlir_tablegen(TorchConversionDialect.h.inc -gen-dialect-decls -dialect=torch_c)
mlir_tablegen(TorchConversionDialect.cpp.inc -gen-dialect-defs -dialect=torch_c)
add_public_tablegen_target(MLIRTorchConversionOpsIncGen)
add_dependencies(mlir-headers MLIRTorchConversionOpsIncGen)
set(LLVM_TARGET_DEFINITIONS TorchConversionTypes.td)
mlir_tablegen(TorchConversionTypes.h.inc -gen-typedef-decls)
mlir_tablegen(TorchConversionTypes.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRTorchConversionTypesIncGen)
add_mlir_doc(TorchConversionDialect TorchConversionDialect TorchConversion/ -gen-dialect-doc)
add_mlir_doc(TorchConversionOps TorchConversionOps TorchConversion/ -gen-op-doc)

View File

@ -0,0 +1,29 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef TORCHCONVERSION_BASE
#define TORCHCONVERSION_BASE
include "mlir/IR/OpBase.td"
def TorchConversion_Dialect : Dialect {
// `torch_conversion` is too verbose.
let name = "torch_c";
let cppNamespace = "::mlir::NPCOMP::TorchConversion";
let description = [{
This dialect contains ops and transforms for converting from the Torch
backend contract to the npcomp backend contract.
This mainly consists of converting ops and types from `torch` dialect
to the mix of dialects of the npcomp backend contract, such as tensor
ops being converted linalg-on-tensors, lists being converted to IREE lists,
and !torch.float being converted to `f64`.
}];
}
#endif // TORCHCONVERSION_BASE

View File

@ -0,0 +1,16 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHDIALECT_H
#define NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHDIALECT_H
#include "mlir/IR/Dialect.h"
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h.inc"
#endif // NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHDIALECT_H

View File

@ -0,0 +1,25 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHOPS_H
#define NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHOPS_H
#include "iree-dialects/Dialect/IREE/IREEDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#define GET_OP_CLASSES
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h.inc"
#endif // NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHOPS_H

View File

@ -0,0 +1,207 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef TORCHCONVERSION_OPS
#define TORCHCONVERSION_OPS
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "npcomp/Dialect/TorchConversion/IR/TorchConversionBase.td"
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
include "iree-dialects/Dialect/IREE/IREEDialect.td"
class TorchConversion_Op<string mnemonic, list<OpTrait> traits = []>
: Op<TorchConversion_Dialect, mnemonic, traits> {
}
//===----------------------------------------------------------------------===//
// Conversions to backend types.
//===----------------------------------------------------------------------===//
def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert a `!torch.vtensor` to a `tensor`";
let description = [{
This op only operates on ValueTensorType, to avoid conflating conversions
between value-semantic and non-value-semantic types.
}];
let arguments = (ins
Torch_ValueTensorType:$operand
);
let results = (outs
AnyTensor:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tensor", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert a `tensor` to a `!torch.vtensor`";
let description = [{
This op only operates on ValueTensorType, to avoid conflating conversions
between value-semantic and non-value-semantic types.
}];
let arguments = (ins
AnyTensor:$operand
);
let results = (outs
Torch_ValueTensorType:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
def TorchConversion_ToI1Op : TorchConversion_Op<"to_i1", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert a `!torch.bool` to an `i1`";
let description = [{
This op is primarily useful as a materialization during dialect conversion.
}];
let arguments = (ins
Torch_BoolType:$operand
);
let results = (outs
I1:$result
);
let assemblyFormat = [{
$operand attr-dict
}];
}
def TorchConversion_FromI1Op : TorchConversion_Op<"from_i1", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert an `i1` to a `!torch.bool`";
let description = [{
This op is primarily useful as a materialization during dialect conversion.
}];
let arguments = (ins
I1:$operand
);
let results = (outs
Torch_BoolType:$result
);
let assemblyFormat = [{
$operand attr-dict
}];
}
def TorchConversion_ToI64Op : TorchConversion_Op<"to_i64", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert a `!torch.int` to an `i64`";
let description = [{
This op is primarily useful as a materialization during dialect conversion.
}];
let arguments = (ins
Torch_IntType:$operand
);
let results = (outs
I64:$result
);
let assemblyFormat = [{
$operand attr-dict
}];
}
def TorchConversion_FromI64Op : TorchConversion_Op<"from_i64", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert an `i64` to a `!torch.int`";
let description = [{
This op is primarily useful as a materialization during dialect conversion.
}];
let arguments = (ins
I64:$operand
);
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = [{
$operand attr-dict
}];
}
def TorchConversion_ToF64Op : TorchConversion_Op<"to_f64", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert a `!torch.float` to an `f64`";
let description = [{
This op is primarily useful as a materialization during dialect conversion.
}];
let arguments = (ins
Torch_FloatType:$operand
);
let results = (outs
F64:$result
);
let assemblyFormat = [{
$operand attr-dict
}];
}
def TorchConversion_FromF64Op : TorchConversion_Op<"from_f64", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert an `f64` to a `!torch.float`";
let description = [{
This op is primarily useful as a materialization during dialect conversion.
}];
let arguments = (ins
F64:$operand
);
let results = (outs
Torch_FloatType:$result
);
let assemblyFormat = [{
$operand attr-dict
}];
}
// TODO: Verify the element types match.
def TorchConversion_ToIREEListOp : TorchConversion_Op<"to_iree_list", [
]> {
let summary = "Convert a `!torch.list` to a `!iree.list`";
let description = [{
}];
let arguments = (ins
Torch_ListType:$operand
);
let results = (outs
IREE_ListType:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
def TorchConversion_FromIREEListOp : TorchConversion_Op<"from_iree_list", [
]> {
let summary = "Convert a `!iree.list` to a `!torch.list`";
let description = [{
}];
let arguments = (ins
IREE_ListType:$operand
);
let results = (outs
Torch_ListType:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
#endif // TORCHCONVERSION_OPS

View File

@ -6,21 +6,26 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_BACKENDTYPECONVERSION_H #ifndef NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_BACKENDTYPECONVERSION_H
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_BACKENDTYPECONVERSION_H #define NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_BACKENDTYPECONVERSION_H
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
namespace NPCOMP { namespace NPCOMP {
namespace Torch { namespace TorchConversion {
/// Get the dependent dialects which might be involved in a backend type
/// conversion.
void getBackendTypeConversionDependentDialects(DialectRegistry &registry);
/// Set up the provided ConversionTarget and TypeConverter for converting /// Set up the provided ConversionTarget and TypeConverter for converting
/// from `torch` dialect types to the types along the npcomp backend boundary /// from `torch` dialect types to the types along the npcomp backend boundary
/// (which currently consist only of builtin types). /// (which currently consist only of builtin types).
void setupBackendTypeConversion(ConversionTarget &target, void setupBackendTypeConversion(ConversionTarget &target,
TypeConverter &typeConverter); TypeConverter &typeConverter);
} // namespace Torch } // namespace TorchConversion
} // namespace NPCOMP } // namespace NPCOMP
} // namespace mlir } // namespace mlir
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_BACKENDTYPECONVERSION_H #endif // NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_BACKENDTYPECONVERSION_H

View File

@ -0,0 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(NPCOMPTorchConversionPassIncGen)
add_mlir_doc(Passes NPCOMPTorchConversionTransforms ./ -gen-pass-doc)

View File

@ -0,0 +1,44 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
#define NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
#include "mlir/Pass/Pass.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include <memory>
namespace mlir {
namespace NPCOMP {
namespace TorchConversion {
/// Creates a pipeline that lowers the object graph IR that is produced by
/// TorchScript import into the form expected by npcomp-verify-backend-contract.
void createTorchScriptToNpcompBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyInvariantsBeforeBackendLoweringPass();
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
std::unique_ptr<OperationPass<FuncOp>>
createFinalizingBackendTypeConversionPass();
std::unique_ptr<OperationPass<FuncOp>> createTmpDeleteDeadIREEListsPass();
} // namespace TorchConversion
/// Registers all Torch transformation passes.
void registerTorchConversionPasses();
} // namespace NPCOMP
} // namespace mlir
#endif // NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H

View File

@ -0,0 +1,76 @@
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_TORCHCONVERSION_PASSES
#define NPCOMP_TORCHCONVERSION_PASSES
include "mlir/Pass/PassBase.td"
def VerifyInvariantsBeforeBackendLowering
: Pass<"torch-verify-invariants-before-backend-lowering", "ModuleOp"> {
let summary = "Verify invariants required by backend lowering";
let constructor =
"mlir::NPCOMP::TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()";
let description = [{
This pass checks any invariants needed by the process of lowering the
`torch` dialect to the npcomp backend contract.
The most important invariant is that all tensors should be ranked and have
a known dtype. It is useful to catch this early because it usually
represents a simple bug in RefineTypes, but can manifest as many different
kinds of obscure symptoms during lowering.
TODO: This pass should probably be phrased as checking the
"torch backend contract" and moved to that dialect once we have more
substantial definition definition around what that layer is from an
"allowlist" perspective.
}];
}
def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "ModuleOp"> {
let summary = "Convert functions to operate on builtin tensors";
let constructor = "mlir::NPCOMP::TorchConversion::createFuncBackendTypeConversionPass()";
let description = [{
Partial type conversion pass analogous in scope to the upstream
`func-bufferize` pass. See details there.
}];
}
def FinalizingBackendTypeConversion
: Pass<"torch-finalizing-backend-type-conversion", "FuncOp"> {
let summary = "Finalizes a partial conversion to builtin tensors";
let constructor =
"mlir::NPCOMP::TorchConversion::createFinalizingBackendTypeConversionPass()";
let description = [{
Analogous in scope to the upstream `finalizing-bufferize` pass.
See details there.
}];
}
def TmpDeleteDeadIREELists
: Pass<"torch-tmp-delete-dead-lists", "FuncOp"> {
let summary = "Delete dead !iree.list ops";
let constructor =
"mlir::NPCOMP::TorchConversion::createTmpDeleteDeadIREEListsPass()";
let description = [{
Runs a few patterns to delete dead !iree.list ops until IREE can support
running them. Currently, these will get materialized as part of conversions
for ops like AtenConv2dOp that have list operands, even though they are dead
(for those ops, we pattern match a specific case of static constant lists).
Currently, this will break execution of those tests because the IREE
side of these ops still doesn't work (nor is IREE able to delete them
itself).
TODO: Add support to IREE to run these ops E2E.
TODO: Remove this pass once IREE can run them e2e.
}];
}
#endif // NPCOMP_TORCHCONVERSION_PASSES

View File

@ -17,6 +17,7 @@ add_npcomp_library(NPCOMPCommonBackend
MLIRTensor MLIRTensor
MLIRStandard MLIRStandard
MLIRMath MLIRMath
IREEDialectsIREEDialect
) )
mlir_check_all_link_libraries(NPCOMPCommonBackend) mlir_check_all_link_libraries(NPCOMPCommonBackend)

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "PassDetail.h" #include "PassDetail.h"
#include "iree-dialects/Dialect/IREE/IREEOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -32,6 +33,7 @@ class VerifyBackendContractPass
return type; return type;
return nullptr; return nullptr;
}); });
converter.addConversion([](iree::ListType type) { return type; });
TypeConverter scalarConverter; TypeConverter scalarConverter;
for (TypeConverter *c : {&converter, &scalarConverter}) { for (TypeConverter *c : {&converter, &scalarConverter}) {
c->addConversion([](FloatType type) { return type; }); c->addConversion([](FloatType type) { return type; });
@ -59,8 +61,7 @@ class VerifyBackendContractPass
// Tensor operations should go through linalg and the tensor dialect. // Tensor operations should go through linalg and the tensor dialect.
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes); target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes);
target.addDynamicallyLegalDialect<tensor::TensorDialect>(opHasLegalTypes); target.addDynamicallyLegalDialect<tensor::TensorDialect>(opHasLegalTypes);
// DimOp is used to query tensor sizes. target.addDynamicallyLegalDialect<iree::IREEDialect>(opHasLegalTypes);
target.addDynamicallyLegalOp<tensor::DimOp>(opHasLegalTypes);
// AssertOp is used to terminate the program for error guards. // AssertOp is used to terminate the program for error guards.
target.addLegalOp<AssertOp>(); target.addLegalOp<AssertOp>();

View File

@ -32,6 +32,7 @@ add_npcomp_library(NPCOMPInitAll
NPCOMPRefBackend NPCOMPRefBackend
NPCOMPRefbackDialect NPCOMPRefbackDialect
NPCOMPTorchDialect NPCOMPTorchDialect
NPCOMPTorchConversionDialect
NPCOMPRefbackrtDialect NPCOMPRefbackrtDialect
NPCOMPBasicpyDialect NPCOMPBasicpyDialect
NPCOMPBasicpyPasses NPCOMPBasicpyPasses
@ -39,6 +40,7 @@ add_npcomp_library(NPCOMPInitAll
NPCOMPNumpyDialect NPCOMPNumpyDialect
NPCOMPNumpyPasses NPCOMPNumpyPasses
NPCOMPTypingPasses NPCOMPTypingPasses
IREEDialectsIREEDialect
# TODO: We shouldn't need npcomp_conversion_libs here, but we have # TODO: We shouldn't need npcomp_conversion_libs here, but we have
# some dialect transform libraries accumulating into that property. # some dialect transform libraries accumulating into that property.

View File

@ -1,3 +1,4 @@
add_subdirectory(TorchToIREE)
add_subdirectory(TorchToLinalg) add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF) add_subdirectory(TorchToSCF)
add_subdirectory(TorchToStd) add_subdirectory(TorchToStd)

View File

@ -9,6 +9,7 @@
#include "npcomp/Conversion/Passes.h" #include "npcomp/Conversion/Passes.h"
#include "npcomp/Conversion/BasicpyToStd/Passes.h" #include "npcomp/Conversion/BasicpyToStd/Passes.h"
#include "npcomp/Conversion/TorchToIREE/TorchToIREE.h"
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h" #include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h" #include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
#include "npcomp/Conversion/TorchToStd/TorchToStd.h" #include "npcomp/Conversion/TorchToStd/TorchToStd.h"

View File

@ -0,0 +1,19 @@
add_npcomp_conversion_library(NPCOMPTorchToIREE
TorchToIREE.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TorchToIREE
DEPENDS
NPCOMPConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
NPCOMPTorchDialect
MLIRStandard
IREEDialectsIREEDialect
)

View File

@ -0,0 +1,89 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Conversion/TorchToIREE/TorchToIREE.h"
#include "../PassDetail.h"
#include "iree-dialects/Dialect/IREE/IREEOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
//===----------------------------------------------------------------------===//
// The patterns
//===----------------------------------------------------------------------===//
namespace {
class ConvertPrimListConstructOp
: public OpConversionPattern<PrimListConstructOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(PrimListConstructOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto type = getTypeConverter()->convertType(op.getType());
auto capacity =
rewriter.create<ConstantIndexOp>(op.getLoc(), op->getNumOperands());
auto ireeList =
rewriter.replaceOpWithNewOp<iree::ListCreateOp>(op, type, capacity);
for (int i = 0, e = operands.size(); i != e; ++i) {
auto index = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
rewriter.create<iree::ListSetOp>(op.getLoc(), ireeList, index,
operands[i]);
}
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// The pass
//===----------------------------------------------------------------------===//
namespace {
class ConvertTorchToIREE : public ConvertTorchToIREEBase<ConvertTorchToIREE> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<StandardOpsDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<iree::IREEDialect>();
target.addLegalDialect<StandardOpsDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);
RewritePatternSet patterns(context);
patterns.add<ConvertPrimListConstructOp>(typeConverter, context);
target.addIllegalOp<PrimListConstructOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertTorchToIREEPass() {
return std::make_unique<ConvertTorchToIREE>();
}

View File

@ -16,7 +16,8 @@
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h" #include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::NPCOMP;
@ -63,7 +64,7 @@ static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
// to end. Constant values can be be extracted directly and non constant // to end. Constant values can be be extracted directly and non constant
// list values are not supported. // list values are not supported.
// TODO: loose this constraint when properly support list type // TODO: loose this constraint when properly support list type
static bool isConstantIntListMatching(Value &value, static bool isConstantIntListMatching(Value value,
llvm::SmallVectorImpl<int64_t> &expects) { llvm::SmallVectorImpl<int64_t> &expects) {
llvm::SmallVector<int64_t> intValues; llvm::SmallVector<int64_t> intValues;
if (!matchPattern(value, m_TorchConstantIntList(intValues))) if (!matchPattern(value, m_TorchConstantIntList(intValues)))
@ -171,7 +172,6 @@ public:
Location loc = op->getLoc(); Location loc = op->getLoc();
MLIRContext *context = op->getContext(); MLIRContext *context = op->getContext();
AtenAdaptiveAvgPool2dOp::Adaptor adaptor(operands); AtenAdaptiveAvgPool2dOp::Adaptor adaptor(operands);
Value outputSize = adaptor.output_size();
Value input = adaptor.self(); /* in form of N*C*H*W */ Value input = adaptor.self(); /* in form of N*C*H*W */
RankedTensorType inputType = input.getType().cast<RankedTensorType>(); RankedTensorType inputType = input.getType().cast<RankedTensorType>();
Type elementType = inputType.getElementType(); Type elementType = inputType.getElementType();
@ -183,7 +183,10 @@ public:
return rewriter.notifyMatchFailure(op, "input should be rank 4"); return rewriter.notifyMatchFailure(op, "input should be rank 4");
SmallVector<int64_t, 2> expects{1, 1}; SmallVector<int64_t, 2> expects{1, 1};
if (!isConstantIntListMatching(outputSize, expects)) // Pattern match against the op's original operands, because otherwise we
// will get the lowered version of the operands which is harder to pattern
// match.
if (!isConstantIntListMatching(op.output_size(), expects))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only support output_size with H and W both equal to constant 1"); op, "only support output_size with H and W both equal to constant 1");
@ -269,9 +272,6 @@ public:
AtenConv2dOp::Adaptor adaptor(operands); AtenConv2dOp::Adaptor adaptor(operands);
Value input = adaptor.input(); /* in form of N*C*H*W */ Value input = adaptor.input(); /* in form of N*C*H*W */
Value weight = adaptor.weight(); /* in form of F*C*H*W */ Value weight = adaptor.weight(); /* in form of F*C*H*W */
Value padding = adaptor.padding();
Value stride = adaptor.stride();
Value dilation = adaptor.dilation();
Value groups = adaptor.groups(); Value groups = adaptor.groups();
Type elementType = Type elementType =
@ -291,18 +291,21 @@ public:
Value weightH = getDimOp(rewriter, loc, weight, 2); Value weightH = getDimOp(rewriter, loc, weight, 2);
Value weightW = getDimOp(rewriter, loc, weight, 3); Value weightW = getDimOp(rewriter, loc, weight, 3);
// Pattern match against the op's original operands, because otherwise we
// will get the lowered version of the operands which is harder to pattern
// match.
llvm::SmallVector<int64_t> paddingInts; llvm::SmallVector<int64_t> paddingInts;
if (!matchPattern(padding, m_TorchConstantIntList(paddingInts))) { if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts))) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only support constant padding values"); op, "only support constant padding values");
} }
llvm::SmallVector<int64_t, 2> strideInts; llvm::SmallVector<int64_t, 2> strideInts;
if (!matchPattern(stride, m_TorchConstantIntList(strideInts))) if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"only support constant int strides"); "only support constant int strides");
llvm::SmallVector<int64_t, 2> dilationInts; llvm::SmallVector<int64_t, 2> dilationInts;
if (!matchPattern(dilation, m_TorchConstantIntList(dilationInts))) if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"only support constant int dilations"); "only support constant int dilations");
if (!op.bias().getType().isa<Torch::NoneType>()) if (!op.bias().getType().isa<Torch::NoneType>())
@ -905,30 +908,29 @@ public:
Location loc = op->getLoc(); Location loc = op->getLoc();
AtenMaxPool2dOp::Adaptor adaptor(operands); AtenMaxPool2dOp::Adaptor adaptor(operands);
Value self = adaptor.self(); Value self = adaptor.self();
Value kernelSize = adaptor.kernel_size();
Value stride = adaptor.stride();
Value padding = adaptor.padding();
Value dilation = adaptor.dilation();
Value ceilMode = adaptor.ceil_mode(); Value ceilMode = adaptor.ceil_mode();
Type elementType = self.getType().cast<RankedTensorType>().getElementType(); Type elementType = self.getType().cast<RankedTensorType>().getElementType();
if (!elementType.isa<mlir::FloatType>()) if (!elementType.isa<mlir::FloatType>())
op.emitError("unimplemented: non-floating point type"); op.emitError("unimplemented: non-floating point type");
// Pattern match against the op's original operands, because otherwise we
// will get the lowered version of the operands which is harder to pattern
// match.
llvm::SmallVector<int64_t, 2> strideInts; llvm::SmallVector<int64_t, 2> strideInts;
if (!matchPattern(stride, m_TorchConstantIntList(strideInts))) if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"only support constant int strides"); "only support constant int strides");
llvm::SmallVector<int64_t, 2> dilationInts; llvm::SmallVector<int64_t, 2> dilationInts;
if (!matchPattern(dilation, m_TorchConstantIntList(dilationInts))) if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"only support constant int dilations"); "only support constant int dilations");
llvm::SmallVector<int64_t, 2> paddingInts; llvm::SmallVector<int64_t, 2> paddingInts;
if (!matchPattern(padding, m_TorchConstantIntList(paddingInts))) if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"only support constant int paddings"); "only support constant int paddings");
llvm::SmallVector<int64_t, 2> kernelSizeInts; llvm::SmallVector<int64_t, 2> kernelSizeInts;
if (!matchPattern(kernelSize, m_TorchConstantIntList(kernelSizeInts))) if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts)))
return rewriter.notifyMatchFailure(op, "only support kernel size ints"); return rewriter.notifyMatchFailure(op, "only support kernel size ints");
Value falseValue = rewriter.create<ConstantOp>( Value falseValue = rewriter.create<ConstantOp>(
@ -1113,6 +1115,7 @@ public:
registry.insert<math::MathDialect>(); registry.insert<math::MathDialect>();
registry.insert<StandardOpsDialect>(); registry.insert<StandardOpsDialect>();
registry.insert<tensor::TensorDialect>(); registry.insert<tensor::TensorDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
} }
void runOnOperation() override { void runOnOperation() override {
@ -1123,7 +1126,7 @@ public:
TypeConverter typeConverter; TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion([](Type type) { return type; });
setupBackendTypeConversion(target, typeConverter); TorchConversion::setupBackendTypeConversion(target, typeConverter);
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
target.addIllegalOp<AtenMmOp>(); target.addIllegalOp<AtenMmOp>();

View File

@ -16,4 +16,5 @@ add_npcomp_conversion_library(NPCOMPTorchToSCF
MLIRSCF MLIRSCF
MLIRStandard MLIRStandard
NPCOMPTorchDialect NPCOMPTorchDialect
NPCOMPTorchConversionDialect
) )

View File

@ -13,7 +13,8 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h" #include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::NPCOMP;
@ -63,6 +64,7 @@ class ConvertTorchToSCF : public ConvertTorchToSCFBase<ConvertTorchToSCF> {
public: public:
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<scf::SCFDialect>(); registry.insert<scf::SCFDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
} }
void runOnOperation() override { void runOnOperation() override {
@ -72,7 +74,7 @@ public:
TypeConverter typeConverter; TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion([](Type type) { return type; });
setupBackendTypeConversion(target, typeConverter); TorchConversion::setupBackendTypeConversion(target, typeConverter);
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
target.addIllegalOp<PrimIfOp>(); target.addIllegalOp<PrimIfOp>();

View File

@ -14,7 +14,8 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h" #include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::NPCOMP;
@ -93,6 +94,7 @@ class ConvertTorchToStd : public ConvertTorchToStdBase<ConvertTorchToStd> {
public: public:
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<StandardOpsDialect>(); registry.insert<StandardOpsDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
} }
void runOnOperation() override { void runOnOperation() override {
@ -102,7 +104,7 @@ public:
TypeConverter typeConverter; TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion([](Type type) { return type; });
setupBackendTypeConversion(target, typeConverter); TorchConversion::setupBackendTypeConversion(target, typeConverter);
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
target.addIllegalOp<AtenDimOp>(); target.addIllegalOp<AtenDimOp>();

View File

@ -3,3 +3,4 @@ add_subdirectory(Numpy)
add_subdirectory(Refback) add_subdirectory(Refback)
add_subdirectory(Refbackrt) add_subdirectory(Refbackrt)
add_subdirectory(Torch) add_subdirectory(Torch)
add_subdirectory(TorchConversion)

View File

@ -688,35 +688,6 @@ void CopyToValueTensorOp::getEffects(
effects.emplace_back(MemoryEffects::Read::get(), getOperand()); effects.emplace_back(MemoryEffects::Read::get(), getOperand());
} }
//===----------------------------------------------------------------------===//
// ToBuiltinTensorOp
//===----------------------------------------------------------------------===//
LogicalResult ToBuiltinTensorOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto resultType =
operands[0].getType().cast<ValueTensorType>().toBuiltinTensor();
if (!resultType)
return failure();
inferredReturnTypes.push_back(resultType);
return success();
}
//===----------------------------------------------------------------------===//
// FromBuiltinTensorOp
//===----------------------------------------------------------------------===//
LogicalResult FromBuiltinTensorOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(
ValueTensorType::getFromShaped(operands[0].getType().cast<TensorType>()));
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ConstantNoneOp // ConstantNoneOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1,6 +1,5 @@
add_npcomp_conversion_library(NPCOMPTorchPasses add_npcomp_conversion_library(NPCOMPTorchPasses
AdjustCallingConventions.cpp AdjustCallingConventions.cpp
BackendTypeConversion.cpp
Passes.cpp Passes.cpp
GlobalizeObjectGraph.cpp GlobalizeObjectGraph.cpp
InlineGlobalSlots.cpp InlineGlobalSlots.cpp
@ -9,7 +8,6 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
ReduceOpVariants.cpp ReduceOpVariants.cpp
RefinePublicReturn.cpp RefinePublicReturn.cpp
RefineTypes.cpp RefineTypes.cpp
VerifyInvariantsBeforeBackendLowering.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms ${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms
@ -24,7 +22,5 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
MLIRIR MLIRIR
MLIRPass MLIRPass
NPCOMPTorchDialect NPCOMPTorchDialect
NPCOMPTorchToLinalg
NPCOMPInterfaces NPCOMPInterfaces
MLIRMemRefTransforms
) )

View File

@ -7,15 +7,8 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "npcomp/Backend/Common/Passes.h"
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
#include "npcomp/Conversion/TorchToStd/TorchToStd.h"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pass registration // Pass registration
@ -29,16 +22,16 @@ namespace {
void mlir::NPCOMP::registerTorchPasses() { void mlir::NPCOMP::registerTorchPasses() {
::registerPasses(); ::registerPasses();
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>( mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-to-npcomp-backend-pipeline", "torchscript-to-torch-backend-pipeline",
"Pipeline lowering torch object graph to npcomp backend format.", "Pipeline lowering TorchScript object graph IR to Torch backend form.",
mlir::NPCOMP::Torch::createLowerObjectGraphPipeline); mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>( mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torch-globalized-module-to-npcomp-backend-pipeline", "torch-globalized-module-to-torch-backend-pipeline",
"Pipeline lowering to npcomp backend form.", "Pipeline lowering a globalized Torch program to Torch backend form.",
mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline); mlir::NPCOMP::Torch::createGlobalizedModuleToTorchBackendPipeline);
} }
void mlir::NPCOMP::Torch::createLowerObjectGraphPipeline( void mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) { OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
// When we import TorchScript IR, we import their entire "compilation unit", // When we import TorchScript IR, we import their entire "compilation unit",
// which can contain numerous functions unrelated to the current program, // which can contain numerous functions unrelated to the current program,
@ -66,10 +59,10 @@ void mlir::NPCOMP::Torch::createLowerObjectGraphPipeline(
// Incorporate user annotations and remove signature Python-isms. // Incorporate user annotations and remove signature Python-isms.
pm.addPass(createAdjustCallingConventionsPass()); pm.addPass(createAdjustCallingConventionsPass());
createLowerToNpcompBackendPipeline(pm, options); createGlobalizedModuleToTorchBackendPipeline(pm, options);
} }
void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline( void mlir::NPCOMP::Torch::createGlobalizedModuleToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) { OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
// General considerations: As a matter of bring-up, we are simultaneously // General considerations: As a matter of bring-up, we are simultaneously
// building out the frontend pipeline and also co-developing the backend // building out the frontend pipeline and also co-developing the backend
@ -140,39 +133,5 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
pm.addNestedPass<FuncOp>(createCanonicalizerPass()); pm.addNestedPass<FuncOp>(createCanonicalizerPass());
} }
//===--------------------------------------------------------------------===// // TODO: VerifyTorchBackendContractPass.
// Lowering ops and the !torch.vtensor type to the npcomp backend contract.
//===--------------------------------------------------------------------===//
// Check some invariants to catch errors in a clear way.
pm.addPass(Torch::createVerifyInvariantsBeforeBackendLoweringPass());
// Lower to linalg + guards which is the input to codegen backends.
// We do this first as it tends to involve pattern-matching against constants,
// (e.g. dimensions which must be constant in a ranked programming model)
// and those constants get somewhat obscured by TorchToStd.
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass());
pm.addNestedPass<FuncOp>(createStdExpandOpsPass());
if (options.optimize) {
// Clean up any non-canonical code introduced in our linalg lowering.
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
// Resolve `dim` ops on tensors (which currently live in the `memref`
// dialect for some reason -- we don't have memrefs at this level).
pm.addNestedPass<FuncOp>(memref::createResolveShapedTypeResultDimsPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<FuncOp>(createCSEPass());
}
// Finish the type conversion from !torch.vtensor to the builtin tensor type.
pm.addPass(createFuncBackendTypeConversionPass());
pm.addNestedPass<FuncOp>(createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to the form that backends expect.
// This fails compilation (signalPassFailure) if the IR is not in the
// correct form.
pm.addPass(CommonBackend::createVerifyBackendContractPass());
} }

View File

@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@ -0,0 +1,18 @@
add_npcomp_dialect_library(NPCOMPTorchConversionDialect
TorchConversionDialect.cpp
TorchConversionOps.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/TorchConversion
DEPENDS
MLIRTorchConversionOpsIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRSupport
MLIRSideEffectInterfaces
)

View File

@ -0,0 +1,51 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::TorchConversion;
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.cpp.inc"
//===----------------------------------------------------------------------===//
// Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct TorchConversionInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
BlockAndValueMapping &valueMapping) const final {
return true;
}
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
BlockAndValueMapping &) const final {
return true;
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Dialect initialize method.
//===----------------------------------------------------------------------===//
void TorchConversionDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc"
>();
addInterfaces<TorchConversionInlinerInterface>();
}

View File

@ -0,0 +1,52 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
#include "llvm/ADT/StringMap.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::TorchConversion;
//===----------------------------------------------------------------------===//
// ToBuiltinTensorOp
//===----------------------------------------------------------------------===//
LogicalResult ToBuiltinTensorOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto resultType =
operands[0].getType().cast<Torch::ValueTensorType>().toBuiltinTensor();
if (!resultType)
return failure();
inferredReturnTypes.push_back(resultType);
return success();
}
//===----------------------------------------------------------------------===//
// FromBuiltinTensorOp
//===----------------------------------------------------------------------===//
LogicalResult FromBuiltinTensorOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(Torch::ValueTensorType::getFromShaped(
operands[0].getType().cast<TensorType>()));
return success();
}
#define GET_OP_CLASSES
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc"

View File

@ -8,24 +8,36 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "iree-dialects/Dialect/IREE/IREEDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h" #include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "npcomp/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch; using namespace mlir::NPCOMP::TorchConversion;
void mlir::NPCOMP::TorchConversion::getBackendTypeConversionDependentDialects(
DialectRegistry &registry) {
registry.insert<TorchConversionDialect>();
registry.insert<iree::IREEDialect>();
}
//===----------------------------------------------------------------------===//
// Type conversion setup.
//===----------------------------------------------------------------------===//
static void static void
setupValueTensorToBuiltinTensorConversion(ConversionTarget &target, setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
TypeConverter &typeConverter) { TypeConverter &typeConverter) {
target.addLegalOp<Torch::ToBuiltinTensorOp, Torch::FromBuiltinTensorOp>(); target.addLegalOp<TorchConversion::ToBuiltinTensorOp,
TorchConversion::FromBuiltinTensorOp>();
typeConverter.addConversion( typeConverter.addConversion(
[](Torch::ValueTensorType type) -> Optional<Type> { [](Torch::ValueTensorType type) -> Optional<Type> {
return type.toBuiltinTensor(); return type.toBuiltinTensor();
@ -34,10 +46,11 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
ValueRange inputs, ValueRange inputs,
Location loc) -> Value { Location loc) -> Value {
assert(inputs.size() == 1); assert(inputs.size() == 1);
assert(inputs[0].getType().isa<BaseTensorType>()); assert(inputs[0].getType().isa<Torch::BaseTensorType>());
return builder.create<ToBuiltinTensorOp>(loc, inputs[0]); return builder.create<ToBuiltinTensorOp>(loc, inputs[0]);
}); });
auto sourceMaterialization = [](OpBuilder &builder, ValueTensorType type, auto sourceMaterialization = [](OpBuilder &builder,
Torch::ValueTensorType type,
ValueRange inputs, Location loc) -> Value { ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1); assert(inputs.size() == 1);
assert(inputs[0].getType().isa<TensorType>()); assert(inputs[0].getType().isa<TensorType>());
@ -49,7 +62,7 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
static void setupTorchBoolToI1Conversion(ConversionTarget &target, static void setupTorchBoolToI1Conversion(ConversionTarget &target,
TypeConverter &typeConverter) { TypeConverter &typeConverter) {
target.addLegalOp<Torch::ToI1Op, Torch::FromI1Op>(); target.addLegalOp<TorchConversion::ToI1Op, TorchConversion::FromI1Op>();
typeConverter.addConversion([](Torch::BoolType type) -> Optional<Type> { typeConverter.addConversion([](Torch::BoolType type) -> Optional<Type> {
return IntegerType::get(type.getContext(), 1); return IntegerType::get(type.getContext(), 1);
}); });
@ -75,7 +88,7 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target,
static void setupTorchIntToI64Conversion(ConversionTarget &target, static void setupTorchIntToI64Conversion(ConversionTarget &target,
TypeConverter &typeConverter) { TypeConverter &typeConverter) {
target.addLegalOp<Torch::ToI64Op, Torch::FromI64Op>(); target.addLegalOp<TorchConversion::ToI64Op, TorchConversion::FromI64Op>();
typeConverter.addConversion([](Torch::IntType type) -> Optional<Type> { typeConverter.addConversion([](Torch::IntType type) -> Optional<Type> {
return IntegerType::get(type.getContext(), 64); return IntegerType::get(type.getContext(), 64);
}); });
@ -101,7 +114,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
static void setupTorchFloatToF64Conversion(ConversionTarget &target, static void setupTorchFloatToF64Conversion(ConversionTarget &target,
TypeConverter &typeConverter) { TypeConverter &typeConverter) {
target.addLegalOp<Torch::ToF64Op, Torch::FromF64Op>(); target.addLegalOp<TorchConversion::ToF64Op, TorchConversion::FromF64Op>();
typeConverter.addConversion([](Torch::FloatType type) -> Optional<Type> { typeConverter.addConversion([](Torch::FloatType type) -> Optional<Type> {
return Float64Type::get(type.getContext()); return Float64Type::get(type.getContext());
}); });
@ -122,12 +135,38 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target,
typeConverter.addArgumentMaterialization(sourceMaterialization); typeConverter.addArgumentMaterialization(sourceMaterialization);
} }
void mlir::NPCOMP::Torch::setupBackendTypeConversion( static void setupTorchListToIREEListConversion(ConversionTarget &target,
TypeConverter &typeConverter) {
target.addLegalOp<TorchConversion::ToIREEListOp,
TorchConversion::FromIREEListOp>();
typeConverter.addConversion([&](Torch::ListType type) -> Optional<Type> {
return iree::ListType::get(
type.getContext(), typeConverter.convertType(type.getContainedType()));
});
typeConverter.addTargetMaterialization(
[](OpBuilder &builder, iree::ListType type, ValueRange inputs,
Location loc) -> Optional<Value> {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<Torch::ListType>());
return builder.create<ToIREEListOp>(loc, type, inputs[0]).getResult();
});
auto sourceMaterialization = [](OpBuilder &builder, Torch::ListType type,
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<iree::ListType>());
return builder.create<FromIREEListOp>(loc, type, inputs[0]);
};
typeConverter.addSourceMaterialization(sourceMaterialization);
typeConverter.addArgumentMaterialization(sourceMaterialization);
}
void mlir::NPCOMP::TorchConversion::setupBackendTypeConversion(
ConversionTarget &target, TypeConverter &typeConverter) { ConversionTarget &target, TypeConverter &typeConverter) {
setupValueTensorToBuiltinTensorConversion(target, typeConverter); setupValueTensorToBuiltinTensorConversion(target, typeConverter);
setupTorchBoolToI1Conversion(target, typeConverter); setupTorchBoolToI1Conversion(target, typeConverter);
setupTorchIntToI64Conversion(target, typeConverter); setupTorchIntToI64Conversion(target, typeConverter);
setupTorchFloatToF64Conversion(target, typeConverter); setupTorchFloatToF64Conversion(target, typeConverter);
setupTorchListToIREEListConversion(target, typeConverter);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -139,6 +178,9 @@ struct FuncBackendTypeConversionPass
: public FuncBackendTypeConversionBase<FuncBackendTypeConversionPass> { : public FuncBackendTypeConversionBase<FuncBackendTypeConversionPass> {
using FuncBackendTypeConversionBase< using FuncBackendTypeConversionBase<
FuncBackendTypeConversionPass>::FuncBackendTypeConversionBase; FuncBackendTypeConversionPass>::FuncBackendTypeConversionBase;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TorchConversion::TorchConversionDialect>();
}
void runOnOperation() override { void runOnOperation() override {
auto module = getOperation(); auto module = getOperation();
auto *context = &getContext(); auto *context = &getContext();
@ -147,7 +189,7 @@ struct FuncBackendTypeConversionPass
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
ConversionTarget target(*context); ConversionTarget target(*context);
typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion([](Type type) { return type; });
setupBackendTypeConversion(target, typeConverter); TorchConversion::setupBackendTypeConversion(target, typeConverter);
populateFuncOpTypeConversionPattern(patterns, typeConverter); populateFuncOpTypeConversionPattern(patterns, typeConverter);
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
@ -176,7 +218,7 @@ struct FuncBackendTypeConversionPass
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createFuncBackendTypeConversionPass() { mlir::NPCOMP::TorchConversion::createFuncBackendTypeConversionPass() {
return std::make_unique<FuncBackendTypeConversionPass>(); return std::make_unique<FuncBackendTypeConversionPass>();
} }
@ -234,13 +276,13 @@ struct FinalizingBackendTypeConversionPass
ConversionTarget target(*context); ConversionTarget target(*context);
typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion([](Type type) { return type; });
setupBackendTypeConversion(target, typeConverter); TorchConversion::setupBackendTypeConversion(target, typeConverter);
// Mark materializations as illegal in this pass (since we are finalizing) // Mark materializations as illegal in this pass (since we are finalizing)
// and add patterns that eliminate them. // and add patterns that eliminate them.
setupFinalization<ToBuiltinTensorOp, FromBuiltinTensorOp, FromI1Op, ToI1Op, setupFinalization<ToBuiltinTensorOp, FromBuiltinTensorOp, FromI1Op, ToI1Op,
FromI64Op, ToI64Op, FromF64Op, ToF64Op>(target, patterns, FromI64Op, ToI64Op, FromF64Op, ToF64Op, FromIREEListOp,
typeConverter); ToIREEListOp>(target, patterns, typeConverter);
// If all result types are legal, and all block arguments are legal, then // If all result types are legal, and all block arguments are legal, then
// all types in the program are legal. // all types in the program are legal.
@ -259,6 +301,6 @@ struct FinalizingBackendTypeConversionPass
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::Torch::createFinalizingBackendTypeConversionPass() { mlir::NPCOMP::TorchConversion::createFinalizingBackendTypeConversionPass() {
return std::make_unique<FinalizingBackendTypeConversionPass>(); return std::make_unique<FinalizingBackendTypeConversionPass>();
} }

View File

@ -0,0 +1,28 @@
add_npcomp_conversion_library(NPCOMPTorchConversionPasses
BackendTypeConversion.cpp
Passes.cpp
TmpDeleteDeadIREELists.cpp
VerifyInvariantsBeforeBackendLowering.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/TorchConversion/Transforms
DEPENDS
NPCOMPTorchConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
NPCOMPTorchConversionDialect
NPCOMPTorchDialect
NPCOMPTorchPasses
NPCOMPTorchToIREE
NPCOMPTorchToLinalg
NPCOMPTorchToStd
NPCOMPTorchToSCF
NPCOMPInterfaces
MLIRMemRefTransforms
)

View File

@ -0,0 +1,25 @@
//===- PassDetail.h - Pass details ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H
#define NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace NPCOMP {
namespace TorchConversion {
#define GEN_PASS_CLASSES
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h.inc"
} // namespace TorchConversion
} // namespace NPCOMP
} // end namespace mlir
#endif // NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H

View File

@ -0,0 +1,97 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "npcomp/Backend/Common/Passes.h"
#include "npcomp/Conversion/TorchToIREE/TorchToIREE.h"
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
#include "npcomp/Conversion/TorchToStd/TorchToStd.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
#define GEN_PASS_REGISTRATION
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h.inc"
} // end namespace
void mlir::NPCOMP::registerTorchConversionPasses() {
::registerPasses();
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-to-npcomp-backend-pipeline",
"Pipeline lowering torch object graph to npcomp backend format.",
mlir::NPCOMP::TorchConversion::createTorchScriptToNpcompBackendPipeline);
}
void mlir::NPCOMP::TorchConversion::createTorchScriptToNpcompBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
// Conversion to the npcomp backend contract starts from the Torch backend
// contract.
Torch::createTorchScriptToTorchBackendPipeline(pm, options);
// Check some invariants to catch errors in a clear way.
pm.addPass(
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
// Lower to linalg + guards which is the input to codegen backends.
// We do this first as it tends to involve pattern-matching against constants,
// (e.g. dimensions which must be constant in a ranked programming model)
// and those constants get somewhat obscured by TorchToStd.
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass());
// Lists and other concepts that don't exist in upstream go through the IREE
// dialect, which we treat as an reasonably well designed interim placeholder
// for the set of ops that we think makes sense in the npcomp backend
// contract. We expect to co-evolve this dialect with npcomp needs, as a lot
// of what we are doing here in npcomp is breaking new ground w.r.t.
// expressiveness and program generality for tensor compilers.
//
// We lower lists last because the lowered form is much harder to reason about
// than the original form.
pm.addNestedPass<FuncOp>(createConvertTorchToIREEPass());
pm.addNestedPass<FuncOp>(createStdExpandOpsPass());
if (options.optimize) {
// Clean up any non-canonical code introduced above..
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
// Resolve `dim` ops on tensors (which currently live in the `memref`
// dialect for some reason -- we don't have memrefs at this level).
pm.addNestedPass<FuncOp>(memref::createResolveShapedTypeResultDimsPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<FuncOp>(createCSEPass());
}
// Finish the type conversion from `torch` types to the types of the npcomp
// backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());
// Temporarily delete dead list ops until IREE can run them e2e.
// TODO: Remove this pass once IREE can run them e2e.
// TODO: Add support to IREE to run these ops E2E.
pm.addNestedPass<FuncOp>(TorchConversion::createTmpDeleteDeadIREEListsPass());
// Verify that we have lowered to the form that npcomp backends expect.
// This fails compilation (signalPassFailure) if the IR is not in the
// correct form.
pm.addPass(CommonBackend::createVerifyBackendContractPass());
}

View File

@ -0,0 +1,56 @@
//===- TmpDeleteDeadIREELists.cpp --------------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "iree-dialects/Dialect/IREE/IREEOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::TorchConversion;
namespace {
class TmpDeleteDeadIREEListsPass
: public TmpDeleteDeadIREEListsBase<TmpDeleteDeadIREEListsPass> {
void runOnOperation() override {
SmallVector<Operation *> toErase;
// Delete lists that are only set (but not read from).
// This is created by our lowering for torch.prim.ListConstruct.
// Until IREE can run such ops e2e (or delete them itself), we need to
// do this cleanup.
// TODO: Add support to IREE to run these ops E2E.
getOperation().walk([&](iree::ListCreateOp op) {
SmallVector<Operation *> deadOps;
deadOps.push_back(op);
for (auto &use : op.getResult().getUses()) {
if (isa<iree::ListSetOp>(use.getOwner())) {
deadOps.push_back(use.getOwner());
} else {
// We can't analyze the list op if it is used by something else.
return;
}
}
llvm::append_range(toErase, deadOps);
});
for (auto *op : toErase) {
op->dropAllDefinedValueUses();
op->erase();
}
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::TorchConversion::createTmpDeleteDeadIREEListsPass() {
return std::make_unique<TmpDeleteDeadIREEListsPass>();
}

View File

@ -11,17 +11,18 @@
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch; using namespace mlir::NPCOMP::TorchConversion;
static LogicalResult checkValueInvariants(Operation *errorReportOp, Value v) { static LogicalResult checkValueInvariants(Operation *errorReportOp, Value v) {
// TODO: Make this an allowlist instead of a denylist. // TODO: Make this an allowlist instead of a denylist.
// TODO: Make this stricter. // TODO: Make this stricter.
auto type = v.getType(); auto type = v.getType();
if (auto valueTensorType = type.dyn_cast<ValueTensorType>()) { if (auto valueTensorType = type.dyn_cast<Torch::ValueTensorType>()) {
if (!valueTensorType.hasDtype() || !valueTensorType.hasSizes()) if (!valueTensorType.hasDtype() || !valueTensorType.hasSizes())
return errorReportOp->emitError() return errorReportOp->emitError()
.append("unsupported by backend lowering: tensor with unknown rank " .append("unsupported by backend lowering: tensor with unknown rank "
@ -77,7 +78,7 @@ class VerifyInvariantsBeforeBackendLoweringPass
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>> mlir::NPCOMP::TorchConversion::
mlir::NPCOMP::Torch::createVerifyInvariantsBeforeBackendLoweringPass() { createVerifyInvariantsBeforeBackendLoweringPass() {
return std::make_unique<VerifyInvariantsBeforeBackendLoweringPass>(); return std::make_unique<VerifyInvariantsBeforeBackendLoweringPass>();
} }

View File

@ -8,6 +8,7 @@
#include "npcomp/InitAll.h" #include "npcomp/InitAll.h"
#include "iree-dialects/Dialect/IREE/IREEDialect.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "npcomp/Backend/Common/Passes.h" #include "npcomp/Backend/Common/Passes.h"
#include "npcomp/Backend/IREE/Passes.h" #include "npcomp/Backend/IREE/Passes.h"
@ -20,6 +21,8 @@
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h" #include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h"
#include "npcomp/RefBackend/RefBackend.h" #include "npcomp/RefBackend/RefBackend.h"
#include "npcomp/Typing/Transforms/Passes.h" #include "npcomp/Typing/Transforms/Passes.h"
@ -29,7 +32,9 @@ void mlir::NPCOMP::registerAllDialects(mlir::DialectRegistry &registry) {
Numpy::NumpyDialect, Numpy::NumpyDialect,
refbackrt::RefbackrtDialect, refbackrt::RefbackrtDialect,
refback::RefbackDialect, refback::RefbackDialect,
mlir::NPCOMP::Torch::TorchDialect>(); mlir::NPCOMP::Torch::TorchDialect,
mlir::NPCOMP::TorchConversion::TorchConversionDialect,
iree::IREEDialect>();
// clang-format on // clang-format on
} }
@ -39,6 +44,7 @@ void mlir::NPCOMP::registerAllPasses() {
mlir::NPCOMP::registerBasicpyPasses(); mlir::NPCOMP::registerBasicpyPasses();
mlir::NPCOMP::registerNumpyPasses(); mlir::NPCOMP::registerNumpyPasses();
mlir::NPCOMP::registerTorchPasses(); mlir::NPCOMP::registerTorchPasses();
mlir::NPCOMP::registerTorchConversionPasses();
mlir::NPCOMP::registerTypingPasses(); mlir::NPCOMP::registerTypingPasses();
mlir::NPCOMP::IREEBackend::registerIREEBackendPasses(); mlir::NPCOMP::IREEBackend::registerIREEBackendPasses();
mlir::NPCOMP::CommonBackend::registerCommonBackendPasses(); mlir::NPCOMP::CommonBackend::registerCommonBackendPasses();

View File

@ -0,0 +1,19 @@
// RUN: npcomp-opt <%s -convert-torch-to-iree -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: builtin.func @forward(
// CHECK-SAME: %[[ARG_TORCH:.*]]: !torch.float) -> !torch.list<!torch.float> {
// CHECK: %[[ARG:.*]] = torch_c.to_f64 %[[ARG_TORCH]]
// CHECK: %[[ALSO_ARG:.*]] = torch_c.to_f64 %[[ARG_TORCH]]
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[LIST:.*]] = iree.list.create %[[C2]] : !iree.list<f64>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: iree.list.set %[[LIST]][%[[C0]]], %[[ARG]] : !iree.list<f64>, f64
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: iree.list.set %[[LIST]][%[[C1]]], %[[ALSO_ARG]] : !iree.list<f64>, f64
// CHECK: %[[LIST_TORCH:.*]] = torch_c.from_iree_list %[[LIST]] : !iree.list<f64> -> !torch.list<!torch.float>
// CHECK: return %[[LIST_TORCH]] : !torch.list<!torch.float>
builtin.func @forward(%arg0: !torch.float) -> !torch.list<!torch.float> {
%0 = torch.prim.ListConstruct %arg0, %arg0 : (!torch.float, !torch.float) -> !torch.list<!torch.float>
return %0 : !torch.list<!torch.float>
}

View File

@ -3,8 +3,8 @@
// CHECK-LABEL: func @torch.aten.mm$basic( // CHECK-LABEL: func @torch.aten.mm$basic(
// CHECK-SAME: %[[LHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[LHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[RHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> { // CHECK-SAME: %[[RHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> {
// CHECK: %[[LHS:.*]] = torch.to_builtin_tensor %[[LHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[LHS:.*]] = torch_c.to_builtin_tensor %[[LHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RHS:.*]] = torch.to_builtin_tensor %[[RHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[RHS:.*]] = torch_c.to_builtin_tensor %[[RHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[LHS_DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : tensor<?x?xf32> // CHECK: %[[LHS_DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C1:.*]] = constant 1 : index
@ -20,7 +20,7 @@
// CHECK: %[[ZEROFILL:.*]] = linalg.fill(%[[CF0]], %[[INIT_TENSOR]]) : f32, tensor<?x?xf32> -> tensor<?x?xf32> // CHECK: %[[ZEROFILL:.*]] = linalg.fill(%[[CF0]], %[[INIT_TENSOR]]) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ZEROFILL]] : tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ZEROFILL]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<?x?xf32> to tensor<?x2xf32> // CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<?x?xf32> to tensor<?x2xf32>
// CHECK: %[[RESULT_VTENSOR:.*]] = torch.from_builtin_tensor %[[CASTED]] : tensor<?x2xf32> -> !torch.vtensor<[?,2],f32> // CHECK: %[[RESULT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x2xf32> -> !torch.vtensor<[?,2],f32>
// CHECK: return %[[RESULT_VTENSOR]] : !torch.vtensor<[?,2],f32> // CHECK: return %[[RESULT_VTENSOR]] : !torch.vtensor<[?,2],f32>
func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> { func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> {
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32> %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32>

View File

@ -3,7 +3,7 @@
// CHECK-LABEL: func @elementwise$unary( // CHECK-LABEL: func @elementwise$unary(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor [] : tensor<f32> // CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor [] : tensor<f32>
// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor<f32>) outs(%[[INIT_TENSOR]] : tensor<f32>) { // CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor<f32>) outs(%[[INIT_TENSOR]] : tensor<f32>) {
// CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32): // CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32):
@ -11,7 +11,7 @@
// CHECK: linalg.yield %[[TANH]] : f32 // CHECK: linalg.yield %[[TANH]] : f32
// CHECK: } -> tensor<f32> // CHECK: } -> tensor<f32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<f32> to tensor<f32> // CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<f32> to tensor<f32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[],f32>
// CHECK: } // CHECK: }
func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
@ -22,8 +22,8 @@ func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32>
// CHECK-LABEL: func @elementwise$binary( // CHECK-LABEL: func @elementwise$binary(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[BUILTIN_ARG0:.*]] = torch.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[BUILTIN_ARG0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[BUILTIN_ARG1:.*]] = torch.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?],f32> -> tensor<?xf32> // CHECK: %[[BUILTIN_ARG1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?],f32> -> tensor<?xf32>
// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[ARG0_DIM0:.*]] = tensor.dim %[[BUILTIN_ARG0]], %[[C0]] : tensor<?x?xf32> // CHECK: %[[ARG0_DIM0:.*]] = tensor.dim %[[BUILTIN_ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C1:.*]] = constant 1 : index
@ -39,7 +39,7 @@ func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32>
// CHECK: linalg.yield %[[MUL]] : f32 // CHECK: linalg.yield %[[MUL]] : f32
// CHECK: } -> tensor<?x?xf32> // CHECK: } -> tensor<?x?xf32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<?x?xf32> to tensor<?x?xf32> // CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[CASTED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @elementwise$binary(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { func @elementwise$binary(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32>
@ -61,7 +61,7 @@ func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vten
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> { // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[BUILTIN_C1:.*]] = torch.to_i64 %[[C1]] // CHECK: %[[BUILTIN_C1:.*]] = torch_c.to_i64 %[[C1]]
// CHECK: linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>] // CHECK: linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>]
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32): // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32):
// CHECK: %[[ALPHA:.*]] = sitofp %[[BUILTIN_C1]] : i64 to f32 // CHECK: %[[ALPHA:.*]] = sitofp %[[BUILTIN_C1]] : i64 to f32

View File

@ -4,10 +4,10 @@
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic( // CHECK-LABEL: func @torch.aten.flatten.using_ints$basic(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> { // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32> // CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32> // CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> { func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
@ -21,10 +21,10 @@ func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic_negative( // CHECK-LABEL: func @torch.aten.flatten.using_ints$basic_negative(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> { // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32> // CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32> // CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> { func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
@ -38,10 +38,10 @@ func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,
// CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_front( // CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_front(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32> // CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32>
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<18x2xf32> to tensor<?x?xf32> // CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<18x2xf32> to tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> { func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
@ -55,10 +55,10 @@ func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2
// CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_back( // CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_back(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> { // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32> // CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32>
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x12xf32> to tensor<?x12xf32> // CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x12xf32> to tensor<?x12xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<?x12xf32> -> !torch.vtensor<[?,12],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<?x12xf32> -> !torch.vtensor<[?,12],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,12],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,12],f32>
func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> { func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
@ -72,9 +72,9 @@ func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2]
// CHECK-LABEL: func @torch.aten.flatten.using_ints$rank0( // CHECK-LABEL: func @torch.aten.flatten.using_ints$rank0(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32> // CHECK: %[[COLLAPSED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[COLLAPSED]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @torch.aten.flatten.using_ints$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { func @torch.aten.flatten.using_ints$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {

View File

@ -0,0 +1,27 @@
// RUN: npcomp-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: builtin.func @forward
builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%int5 = torch.constant.int 5
%int6 = torch.constant.int 6
%int7 = torch.constant.int 7
%int8 = torch.constant.int 8
%false = torch.constant.bool false
// CHECK: %[[PADDED:.*]] = linalg.pad_tensor %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
// CHECK: %[[NEUTRAL:.*]] = constant -1.401300e-45 : f32
// CHECK: %[[OUT:.*]] = linalg.fill(%[[NEUTRAL]], %{{.*}}) : f32, tensor<?x?x?x?xf32> -> tensor<?x?x?x?xf32>
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[C1]], %[[C2]]] : tensor<?x?xf32>
// CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?x?xf32>, tensor<?x?xf32>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
%kernel_size = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%stride = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%padding = torch.prim.ListConstruct %int5, %int6 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%dilation = torch.prim.ListConstruct %int7, %int8 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%4 = torch.aten.max_pool2d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
return %4 : !torch.vtensor<[?,?,?,?],f32>
}

View File

@ -5,9 +5,9 @@
// CHECK-LABEL: func @torch.aten.unsqueeze$basic( // CHECK-LABEL: func @torch.aten.unsqueeze$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32> // CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
@ -17,9 +17,9 @@ func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtenso
// CHECK-LABEL: func @torch.aten.unsqueeze$basic_negative( // CHECK-LABEL: func @torch.aten.unsqueeze$basic_negative(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32> // CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
%int-1 = torch.constant.int -1 %int-1 = torch.constant.int -1
@ -29,9 +29,9 @@ func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !tor
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_front( // CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_front(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32> // CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[EXPANDED]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,3,4],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,3,4],f32>
func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> { func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
@ -41,9 +41,9 @@ func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>)
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_back( // CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_back(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32> // CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x4x1xf32> -> !torch.vtensor<[2,3,4,1],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x4x1xf32> -> !torch.vtensor<[2,3,4,1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4,1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4,1],f32>
func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> { func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
%int-1 = torch.constant.int -1 %int-1 = torch.constant.int -1
@ -53,9 +53,9 @@ func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>)
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_middle( // CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_middle(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32> // CHECK: %[[EXPANDED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x1x4xf32> -> !torch.vtensor<[2,3,1,4],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x1x4xf32> -> !torch.vtensor<[2,3,1,4],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,1,4],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,1,4],f32>
func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> { func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2

View File

@ -4,15 +4,15 @@
// CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int { // CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int {
// CHECK: %[[VAL_1:.*]] = torch.constant.int 2 // CHECK: %[[VAL_1:.*]] = torch.constant.int 2
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 // CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.to_i1 %[[VAL_0]] // CHECK: %[[VAL_3:.*]] = torch_c.to_i1 %[[VAL_0]]
// CHECK: %[[VAL_4:.*]] = scf.if %[[VAL_3]] -> (i64) { // CHECK: %[[VAL_4:.*]] = scf.if %[[VAL_3]] -> (i64) {
// CHECK: %[[VAL_5:.*]] = torch.to_i64 %[[VAL_1]] // CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_1]]
// CHECK: scf.yield %[[VAL_5]] : i64 // CHECK: scf.yield %[[VAL_5]] : i64
// CHECK: } else { // CHECK: } else {
// CHECK: %[[VAL_6:.*]] = torch.to_i64 %[[VAL_2]] // CHECK: %[[VAL_6:.*]] = torch_c.to_i64 %[[VAL_2]]
// CHECK: scf.yield %[[VAL_6]] : i64 // CHECK: scf.yield %[[VAL_6]] : i64
// CHECK: } // CHECK: }
// CHECK: %[[VAL_7:.*]] = torch.from_i64 %[[VAL_8:.*]] // CHECK: %[[VAL_7:.*]] = torch_c.from_i64 %[[VAL_8:.*]]
// CHECK: return %[[VAL_7]] : !torch.int // CHECK: return %[[VAL_7]] : !torch.int
func @torch.prim.if(%arg0: !torch.bool) -> !torch.int { func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
@ -31,22 +31,22 @@ func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2 // CHECK: %[[VAL_2:.*]] = torch.constant.int 2
// CHECK: %[[VAL_3:.*]] = torch.constant.int 3 // CHECK: %[[VAL_3:.*]] = torch.constant.int 3
// CHECK: %[[VAL_4:.*]] = torch.constant.int 4 // CHECK: %[[VAL_4:.*]] = torch.constant.int 4
// CHECK: %[[VAL_5:.*]] = torch.to_i1 %[[VAL_0]] // CHECK: %[[VAL_5:.*]] = torch_c.to_i1 %[[VAL_0]]
// CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_5]] -> (i64) { // CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_5]] -> (i64) {
// CHECK: %[[VAL_7:.*]] = torch.to_i1 %[[VAL_1]] // CHECK: %[[VAL_7:.*]] = torch_c.to_i1 %[[VAL_1]]
// CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (i64) { // CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (i64) {
// CHECK: %[[VAL_9:.*]] = torch.to_i64 %[[VAL_2]] // CHECK: %[[VAL_9:.*]] = torch_c.to_i64 %[[VAL_2]]
// CHECK: scf.yield %[[VAL_9]] : i64 // CHECK: scf.yield %[[VAL_9]] : i64
// CHECK: } else { // CHECK: } else {
// CHECK: %[[VAL_10:.*]] = torch.to_i64 %[[VAL_3]] // CHECK: %[[VAL_10:.*]] = torch_c.to_i64 %[[VAL_3]]
// CHECK: scf.yield %[[VAL_10]] : i64 // CHECK: scf.yield %[[VAL_10]] : i64
// CHECK: } // CHECK: }
// CHECK: scf.yield %[[VAL_11:.*]] : i64 // CHECK: scf.yield %[[VAL_11:.*]] : i64
// CHECK: } else { // CHECK: } else {
// CHECK: %[[VAL_12:.*]] = torch.to_i64 %[[VAL_4]] // CHECK: %[[VAL_12:.*]] = torch_c.to_i64 %[[VAL_4]]
// CHECK: scf.yield %[[VAL_12]] : i64 // CHECK: scf.yield %[[VAL_12]] : i64
// CHECK: } // CHECK: }
// CHECK: %[[VAL_13:.*]] = torch.from_i64 %[[VAL_14:.*]] // CHECK: %[[VAL_13:.*]] = torch_c.from_i64 %[[VAL_14:.*]]
// CHECK: return %[[VAL_13]] : !torch.int // CHECK: return %[[VAL_13]] : !torch.int
func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int { func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int {
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2

View File

@ -3,10 +3,10 @@
// CHECK-LABEL: func @torch.aten.dim( // CHECK-LABEL: func @torch.aten.dim(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.int { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.int {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[ARG]] : !torch.vtensor<*,f32> -> tensor<*xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<*,f32> -> tensor<*xf32>
// CHECK: %[[RANK:.*]] = rank %[[BUILTIN_TENSOR]] : tensor<*xf32> // CHECK: %[[RANK:.*]] = rank %[[BUILTIN_TENSOR]] : tensor<*xf32>
// CHECK: %[[RANK_I64:.*]] = index_cast %[[RANK]] : index to i64 // CHECK: %[[RANK_I64:.*]] = index_cast %[[RANK]] : index to i64
// CHECK: %[[RANK_TORCH_INT:.*]] = torch.from_i64 %[[RANK_I64]] // CHECK: %[[RANK_TORCH_INT:.*]] = torch_c.from_i64 %[[RANK_I64]]
// CHECK: return %[[RANK_TORCH_INT]] : !torch.int // CHECK: return %[[RANK_TORCH_INT]] : !torch.int
func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int { func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int {
%0 = torch.aten.dim %arg0 : !torch.vtensor<*,f32> -> !torch.int %0 = torch.aten.dim %arg0 : !torch.vtensor<*,f32> -> !torch.int
@ -16,10 +16,10 @@ func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int {
// CHECK-LABEL: func @torch.aten.ne.int( // CHECK-LABEL: func @torch.aten.ne.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
// CHECK: %[[LHS_I64:.*]] = torch.to_i64 %[[LHS]] // CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
// CHECK: %[[RHS_I64:.*]] = torch.to_i64 %[[RHS]] // CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
// CHECK: %[[CMP:.*]] = cmpi ne, %[[LHS_I64]], %[[RHS_I64]] : i64 // CHECK: %[[CMP:.*]] = cmpi ne, %[[LHS_I64]], %[[RHS_I64]] : i64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch.from_i1 %[[CMP]] // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool { func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
%0 = torch.aten.ne.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool %0 = torch.aten.ne.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
@ -29,10 +29,10 @@ func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
// CHECK-LABEL: func @torch.aten.gt.int( // CHECK-LABEL: func @torch.aten.gt.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
// CHECK: %[[LHS_I64:.*]] = torch.to_i64 %[[LHS]] // CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
// CHECK: %[[RHS_I64:.*]] = torch.to_i64 %[[RHS]] // CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
// CHECK: %[[CMP:.*]] = cmpi sgt, %[[LHS_I64]], %[[RHS_I64]] : i64 // CHECK: %[[CMP:.*]] = cmpi sgt, %[[LHS_I64]], %[[RHS_I64]] : i64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch.from_i1 %[[CMP]] // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool { func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
%0 = torch.aten.gt.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool %0 = torch.aten.gt.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
@ -41,7 +41,7 @@ func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
// CHECK-LABEL: func @torch.vtensor.literal() -> !torch.vtensor<[],f32> { // CHECK-LABEL: func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
// CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<f32> // CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VTENSOR:.*]] = torch.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32> // CHECK: %[[VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[VTENSOR]] : !torch.vtensor<[],f32> // CHECK: return %[[VTENSOR]] : !torch.vtensor<[],f32>
func @torch.vtensor.literal() -> !torch.vtensor<[],f32> { func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
%0 = torch.vtensor.literal(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32> %0 = torch.vtensor.literal(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32>
@ -50,7 +50,7 @@ func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
// CHECK-LABEL: func @torch.constant.bool() -> !torch.bool { // CHECK-LABEL: func @torch.constant.bool() -> !torch.bool {
// CHECK: %[[CST:.*]] = constant true // CHECK: %[[CST:.*]] = constant true
// CHECK: %[[BOOL:.*]] = torch.from_i1 %[[CST]] // CHECK: %[[BOOL:.*]] = torch_c.from_i1 %[[CST]]
// CHECK: return %[[BOOL]] : !torch.bool // CHECK: return %[[BOOL]] : !torch.bool
func @torch.constant.bool() -> !torch.bool { func @torch.constant.bool() -> !torch.bool {
%true = torch.constant.bool true %true = torch.constant.bool true
@ -59,7 +59,7 @@ func @torch.constant.bool() -> !torch.bool {
// CHECK-LABEL: func @torch.constant.float() -> !torch.float { // CHECK-LABEL: func @torch.constant.float() -> !torch.float {
// CHECK: %[[CST:.*]] = constant 1.000000e+00 : f64 // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f64
// CHECK: %[[FLOAT:.*]] = torch.from_f64 %[[CST]] // CHECK: %[[FLOAT:.*]] = torch_c.from_f64 %[[CST]]
// CHECK: return %[[FLOAT]] : !torch.float // CHECK: return %[[FLOAT]] : !torch.float
func @torch.constant.float() -> !torch.float { func @torch.constant.float() -> !torch.float {
%float = torch.constant.float 1.000000e+00 %float = torch.constant.float 1.000000e+00
@ -68,7 +68,7 @@ func @torch.constant.float() -> !torch.float {
// CHECK-LABEL: func @torch.constant.int() -> !torch.int { // CHECK-LABEL: func @torch.constant.int() -> !torch.int {
// CHECK: %[[CST:.*]] = constant 1 : i64 // CHECK: %[[CST:.*]] = constant 1 : i64
// CHECK: %[[INT:.*]] = torch.from_i64 %[[CST]] // CHECK: %[[INT:.*]] = torch_c.from_i64 %[[CST]]
// CHECK: return %[[INT]] : !torch.int // CHECK: return %[[INT]] : !torch.int
func @torch.constant.int() -> !torch.int { func @torch.constant.int() -> !torch.int {
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1

View File

@ -13,19 +13,6 @@ func @torch.linear_params.create(%arg0: !torch.tensor, %arg1: !torch.tensor) ->
return %with_bias, %without_bias : !torch.LinearParams, !torch.LinearParams return %with_bias, %without_bias : !torch.LinearParams, !torch.LinearParams
} }
// CHECK-LABEL: func @builtin_tensor_interop(
func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xsi8>, %arg2: !torch.vtensor<*,f32>, %arg3: !torch.vtensor<[3,?],si8>) {
// CHECK: torch.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
%0 = torch.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
// CHECK: torch.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8>
%1 = torch.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8>
// CHECK: torch.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32>
%2 = torch.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32>
// CHECK: torch.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xsi8>
%3 = torch.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xsi8>
return
}
// CHECK: @tensor.default() -> !torch.tensor // CHECK: @tensor.default() -> !torch.tensor
func private @tensor.default() -> !torch.tensor func private @tensor.default() -> !torch.tensor
// CHECK: @tensor.default_explicit() -> !torch.tensor{{$}} // CHECK: @tensor.default_explicit() -> !torch.tensor{{$}}

View File

@ -7,8 +7,8 @@
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> { // CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK: return %[[ARG]] : tensor<f32> // CHECK: return %[[ARG]] : tensor<f32>
func @eliminate_materializations(%arg0: tensor<f32>) -> tensor<f32> { func @eliminate_materializations(%arg0: tensor<f32>) -> tensor<f32> {
%0 = torch.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32> %0 = torch_c.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32>
%1 = torch.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32> %1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32>
return %1 : tensor<f32> return %1 : tensor<f32>
} }
@ -19,8 +19,8 @@ func @eliminate_materializations(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-SAME: %[[ARG:.*]]: i1) -> i1 { // CHECK-SAME: %[[ARG:.*]]: i1) -> i1 {
// CHECK: return %[[ARG]] : i1 // CHECK: return %[[ARG]] : i1
func @eliminate_materializations$torch.bool(%arg0: i1) -> i1 { func @eliminate_materializations$torch.bool(%arg0: i1) -> i1 {
%0 = torch.from_i1 %arg0 %0 = torch_c.from_i1 %arg0
%1 = torch.to_i1 %0 %1 = torch_c.to_i1 %0
return %1 : i1 return %1 : i1
} }
@ -28,8 +28,8 @@ func @eliminate_materializations$torch.bool(%arg0: i1) -> i1 {
// CHECK-SAME: %[[ARG:.*]]: i64) -> i64 { // CHECK-SAME: %[[ARG:.*]]: i64) -> i64 {
// CHECK: return %[[ARG]] : i64 // CHECK: return %[[ARG]] : i64
func @eliminate_materializations$torch.int(%arg0: i64) -> i64 { func @eliminate_materializations$torch.int(%arg0: i64) -> i64 {
%0 = torch.from_i64 %arg0 %0 = torch_c.from_i64 %arg0
%1 = torch.to_i64 %0 %1 = torch_c.to_i64 %0
return %1 : i64 return %1 : i64
} }
@ -37,24 +37,33 @@ func @eliminate_materializations$torch.int(%arg0: i64) -> i64 {
// CHECK-SAME: %[[ARG:.*]]: f64) -> f64 { // CHECK-SAME: %[[ARG:.*]]: f64) -> f64 {
// CHECK: return %[[ARG]] : f64 // CHECK: return %[[ARG]] : f64
func @eliminate_materializations$torch.float(%arg0: f64) -> f64 { func @eliminate_materializations$torch.float(%arg0: f64) -> f64 {
%0 = torch.from_f64 %arg0 %0 = torch_c.from_f64 %arg0
%1 = torch.to_f64 %0 %1 = torch_c.to_f64 %0
return %1 : f64 return %1 : f64
} }
// CHECK-LABEL: func @eliminate_materializations$torch.list(
// CHECK-SAME: %[[ARG:.*]]: !iree.list<f64>) -> !iree.list<f64> {
// CHECK: return %[[ARG]] : !iree.list<f64>
func @eliminate_materializations$torch.list(%arg0: !iree.list<f64>) -> !iree.list<f64> {
%0 = torch_c.from_iree_list %arg0 : !iree.list<f64> -> !torch.list<!torch.float>
%1 = torch_c.to_iree_list %0 : !torch.list<!torch.float> -> !iree.list<f64>
return %1 : !iree.list<f64>
}
// ----- // -----
func @unable_to_convert_lone_buffer_cast() -> tensor<f32> { func @unable_to_convert_lone_buffer_cast() -> tensor<f32> {
// expected-error @+1 {{failed to legalize operation 'test.source'}} // expected-error @+1 {{failed to legalize operation 'test.source'}}
%0 = "test.source"() : () -> !torch.vtensor<[],f32> %0 = "test.source"() : () -> !torch.vtensor<[],f32>
%1 = torch.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32> %1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32>
return %1 : tensor<f32> return %1 : tensor<f32>
} }
// ----- // -----
func @unable_to_convert_lone_tensor_load(%arg0: tensor<f32>) { func @unable_to_convert_lone_tensor_load(%arg0: tensor<f32>) {
%0 = torch.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32> %0 = torch_c.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32>
// expected-error @+1 {{failed to legalize operation 'test.sink'}} // expected-error @+1 {{failed to legalize operation 'test.sink'}}
"test.sink"(%0) : (!torch.vtensor<[],f32>) -> () "test.sink"(%0) : (!torch.vtensor<[],f32>) -> ()
return return

View File

@ -5,8 +5,8 @@
// CHECK-LABEL: func @identity( // CHECK-LABEL: func @identity(
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> { // CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK: %[[TENSOR:.*]] = torch.from_builtin_tensor %[[ARG]] : tensor<f32> -> !torch.vtensor<[],f32> // CHECK: %[[TENSOR:.*]] = torch_c.from_builtin_tensor %[[ARG]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: %[[MEMREF:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32> // CHECK: %[[MEMREF:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: return %[[MEMREF]] : tensor<f32> // CHECK: return %[[MEMREF]] : tensor<f32>
func @identity(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { func @identity(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
return %arg0 : !torch.vtensor<[],f32> return %arg0 : !torch.vtensor<[],f32>
@ -14,12 +14,12 @@ func @identity(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
// CHECK-LABEL: func @block_arguments( // CHECK-LABEL: func @block_arguments(
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> { // CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK: %[[T1:.*]] = torch.from_builtin_tensor %[[ARG]] : tensor<f32> -> !torch.vtensor<[],f32> // CHECK: %[[T1:.*]] = torch_c.from_builtin_tensor %[[ARG]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: %[[M1:.*]] = torch.to_builtin_tensor %[[T1]] : !torch.vtensor<[],f32> -> tensor<f32> // CHECK: %[[M1:.*]] = torch_c.to_builtin_tensor %[[T1]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: br ^bb1(%[[M1]] : tensor<f32>) // CHECK: br ^bb1(%[[M1]] : tensor<f32>)
// CHECK: ^bb1(%[[BBARG:.*]]: tensor<f32>): // CHECK: ^bb1(%[[BBARG:.*]]: tensor<f32>):
// CHECK: %[[T2:.*]] = torch.from_builtin_tensor %[[BBARG]] : tensor<f32> -> !torch.vtensor<[],f32> // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[BBARG]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: %[[M2:.*]] = torch.to_builtin_tensor %[[T2]] : !torch.vtensor<[],f32> -> tensor<f32> // CHECK: %[[M2:.*]] = torch_c.to_builtin_tensor %[[T2]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: return %[[M2]] : tensor<f32> // CHECK: return %[[M2]] : tensor<f32>
func @block_arguments(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { func @block_arguments(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
br ^bb1(%arg0: !torch.vtensor<[],f32>) br ^bb1(%arg0: !torch.vtensor<[],f32>)
@ -38,8 +38,8 @@ func @call_source() -> !torch.vtensor<[],f32> {
} }
// CHECK-LABEL: func @call_sink( // CHECK-LABEL: func @call_sink(
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) { // CHECK-SAME: %[[ARG:.*]]: tensor<f32>) {
// CHECK: %[[TENSOR:.*]] = torch.from_builtin_tensor %[[ARG]] : tensor<f32> -> !torch.vtensor<[],f32> // CHECK: %[[TENSOR:.*]] = torch_c.from_builtin_tensor %[[ARG]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: %[[MEMREF:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32> // CHECK: %[[MEMREF:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: call @sink(%[[MEMREF]]) : (tensor<f32>) -> () // CHECK: call @sink(%[[MEMREF]]) : (tensor<f32>) -> ()
// CHECK: return // CHECK: return
func private @sink(!torch.vtensor<[],f32>) func private @sink(!torch.vtensor<[],f32>)
@ -50,7 +50,7 @@ func @call_sink(%arg0: !torch.vtensor<[],f32>) {
// CHECK-LABEL: func @unconverted_op_in_body() -> tensor<f32> { // CHECK-LABEL: func @unconverted_op_in_body() -> tensor<f32> {
// CHECK: %[[TENSOR:.*]] = "test.source"() : () -> !torch.vtensor<[],f32> // CHECK: %[[TENSOR:.*]] = "test.source"() : () -> !torch.vtensor<[],f32>
// CHECK: %[[MEMREF:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32> // CHECK: %[[MEMREF:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: return %[[MEMREF]] : tensor<f32> // CHECK: return %[[MEMREF]] : tensor<f32>
func @unconverted_op_in_body() -> !torch.vtensor<[],f32> { func @unconverted_op_in_body() -> !torch.vtensor<[],f32> {
%0 = "test.source"() : () -> !torch.vtensor<[],f32> %0 = "test.source"() : () -> !torch.vtensor<[],f32>
@ -98,8 +98,8 @@ func @bwhile(%arg0: i64, %arg1: i64) -> i64 {
// CHECK-LABEL: func @identity$torch.bool( // CHECK-LABEL: func @identity$torch.bool(
// CHECK-SAME: %[[ARG:.*]]: i1) -> i1 { // CHECK-SAME: %[[ARG:.*]]: i1) -> i1 {
// CHECK: %[[TORCH_BOOL:.*]] = torch.from_i1 %[[ARG]] // CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[ARG]]
// CHECK: %[[I1:.*]] = torch.to_i1 %[[TORCH_BOOL]] // CHECK: %[[I1:.*]] = torch_c.to_i1 %[[TORCH_BOOL]]
// CHECK: return %[[I1]] : i1 // CHECK: return %[[I1]] : i1
func @identity$torch.bool(%arg0: !torch.bool) -> !torch.bool { func @identity$torch.bool(%arg0: !torch.bool) -> !torch.bool {
return %arg0 : !torch.bool return %arg0 : !torch.bool
@ -107,8 +107,8 @@ func @identity$torch.bool(%arg0: !torch.bool) -> !torch.bool {
// CHECK-LABEL: func @identity$torch.int( // CHECK-LABEL: func @identity$torch.int(
// CHECK-SAME: %[[ARG:.*]]: i64) -> i64 { // CHECK-SAME: %[[ARG:.*]]: i64) -> i64 {
// CHECK: %[[TORCH_INT:.*]] = torch.from_i64 %[[ARG]] // CHECK: %[[TORCH_INT:.*]] = torch_c.from_i64 %[[ARG]]
// CHECK: %[[I64:.*]] = torch.to_i64 %[[TORCH_INT]] // CHECK: %[[I64:.*]] = torch_c.to_i64 %[[TORCH_INT]]
// CHECK: return %[[I64]] : i64 // CHECK: return %[[I64]] : i64
func @identity$torch.int(%arg0: !torch.int) -> !torch.int { func @identity$torch.int(%arg0: !torch.int) -> !torch.int {
return %arg0 : !torch.int return %arg0 : !torch.int
@ -116,8 +116,8 @@ func @identity$torch.int(%arg0: !torch.int) -> !torch.int {
// CHECK-LABEL: func @identity$torch.float( // CHECK-LABEL: func @identity$torch.float(
// CHECK-SAME: %[[ARG:.*]]: f64) -> f64 { // CHECK-SAME: %[[ARG:.*]]: f64) -> f64 {
// CHECK: %[[TORCH_FLOAT:.*]] = torch.from_f64 %[[ARG]] // CHECK: %[[TORCH_FLOAT:.*]] = torch_c.from_f64 %[[ARG]]
// CHECK: %[[F64:.*]] = torch.to_f64 %[[TORCH_FLOAT]] // CHECK: %[[F64:.*]] = torch_c.to_f64 %[[TORCH_FLOAT]]
// CHECK: return %[[F64]] : f64 // CHECK: return %[[F64]] : f64
func @identity$torch.float(%arg0: !torch.float) -> !torch.float { func @identity$torch.float(%arg0: !torch.float) -> !torch.float {
return %arg0 : !torch.float return %arg0 : !torch.float

View File

@ -0,0 +1,14 @@
// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s
// CHECK-LABEL: func @builtin_tensor_interop(
func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xsi8>, %arg2: !torch.vtensor<*,f32>, %arg3: !torch.vtensor<[3,?],si8>) {
// CHECK: torch_c.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
%0 = torch_c.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
// CHECK: torch_c.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8>
%1 = torch_c.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8>
// CHECK: torch_c.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32>
%2 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32>
// CHECK: torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xsi8>
%3 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xsi8>
return
}