mirror of https://github.com/llvm/torch-mlir
Add a torch.kernel_call op and associated predicates.
parent
ba03ecc652
commit
3d74337be0
|
@ -10,6 +10,8 @@
|
|||
#define TORCH_BASE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.td"
|
||||
include "npcomp/Dialect/Numpy/IR/NumpyDialect.td"
|
||||
|
||||
def Torch_Dialect : Dialect {
|
||||
let name = "torch";
|
||||
|
@ -38,4 +40,60 @@ def Torch_Dialect : Dialect {
|
|||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type predicates
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Torch has a fairly advanced and featureful Tensor type, and some of the
|
||||
// semantics are important to preserve in a compilation context. In the future,
|
||||
// a dedicated TorchTensor type may be introduced, but also, subsets of cases
|
||||
// and interop are well served by existing tensor-like types, which are
|
||||
// specifically permitted. Typically, on import, constraints are fairly loose
|
||||
// and based on how the program is captured. Settling on and refining to
|
||||
// specific types is done as part of lowering.
|
||||
//
|
||||
// While lowering it is useful to be able to distinguish between mutable and
|
||||
// immutable tensors:
|
||||
// - Mutable tensors can alias.
|
||||
// - Mutable tensors can be a view over another mutable tensor.
|
||||
// - Mutable tensors act as if reference counted and exist for the lifetime
|
||||
// of any reference or derived view.
|
||||
// Conversely, immutable tensors:
|
||||
// - Are normal SSA values representing the contents of the tensor.
|
||||
// - Cannot alias.
|
||||
// - Cannot be a view of any mutable value.
|
||||
// - Have undefined lifetimes.
|
||||
//
|
||||
// At the Torch dialect level, most things are modeled as an AnyTorchTensor;
|
||||
// however, when lowering to specific ops, further constraints are introduced,
|
||||
// necessitating copies, loads, and stores to be inserted to bridge worlds.
|
||||
def AnyTorchImmutableTensor : AnyTypeOf<[
|
||||
// Normal MLIR immutable tensors.
|
||||
AnyTensor,
|
||||
], "allowable torch immutable tensor">;
|
||||
|
||||
def AnyTorchMutableTensor : AnyTypeOf<[
|
||||
// "Numpy-style" mutable NDArray. While not offering the full generality
|
||||
// of a Torch tensor, it models the same access patterns and implies the
|
||||
// same aliasing as Torch tensors.
|
||||
Numpy_NdArrayType,
|
||||
], "allowable torch mutable tensor">;
|
||||
|
||||
def AnyTorchTensorType : AnyTypeOf<[
|
||||
AnyTorchImmutableTensor,
|
||||
AnyTorchMutableTensor,
|
||||
], "Any tensor type legal to pass to a Torch kernel">;
|
||||
|
||||
def AnyScalar : AnyTypeOf<[
|
||||
AnySignedInteger,
|
||||
Basicpy_BoolType,
|
||||
Basicpy_StrType,
|
||||
Basicpy_NoneType,
|
||||
], "Any primitive type suitable to be passed as a Torch Scalar">;
|
||||
|
||||
def AnyTorchType : AnyTypeOf<[
|
||||
AnyScalar,
|
||||
AnyTorchTensorType,
|
||||
], "Any type that is legal to pass to a Torch kernel">;
|
||||
|
||||
#endif // TORCH_BASE
|
||||
|
|
|
@ -15,8 +15,28 @@ class Torch_Op<string mnemonic, list<OpTrait> traits = []>
|
|||
: Op<Torch_Dialect, mnemonic, traits> {
|
||||
}
|
||||
|
||||
def Torch_DummyOp : Torch_Op<"dummy"> {
|
||||
let summary = "Dummy placeholder op until more is defined";
|
||||
def Torch_KernelCall : Torch_Op<"kernel_call"> {
|
||||
let summary = "Calls a Torch custom kernel";
|
||||
let description = [{
|
||||
Torch kernel calls are matched by the runtime based on signature, including
|
||||
the fully qualified kernel name (i.e. "namespace::name") and the tuple of
|
||||
argument types. This op models such an invocation.
|
||||
|
||||
Caveat: When interacting with the PyTorch boxed interpreter stack, the
|
||||
arguments follow a convention that the top of stack is the first argument.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
StrAttr:$kernel_name,
|
||||
Variadic<AnyTorchType>:$args
|
||||
);
|
||||
let results = (outs
|
||||
Variadic<AnyTorchType>:$results
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$kernel_name $args attr-dict `:` functional-type($args, results)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TORCH_OPS
|
||||
|
|
|
@ -14,4 +14,6 @@ add_mlir_dialect_library(NPCOMPTorchDialect
|
|||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRSupport
|
||||
NPCOMPBasicpyDialect
|
||||
NPCOMPNumpyDialect
|
||||
)
|
||||
|
|
|
@ -8,8 +8,16 @@
|
|||
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Basicpy;
|
||||
using namespace mlir::NPCOMP::Numpy;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
// RUN: npcomp-opt %s | FileCheck %s
|
||||
|
||||
func @dummy() {
|
||||
// CHECK: "torch.dummy"
|
||||
"torch.dummy"() : () -> ()
|
||||
return
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s
|
||||
|
||||
func @kernel_call(%arg0 : si32, %arg1 : tensor<3x4xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %0 = torch.kernel_call "somens::someop" %arg0, %arg1 : (si32, tensor<3x4xf32>) -> tensor<*xf32>
|
||||
%1 = torch.kernel_call "somens::someop" %arg0, %arg1 : (si32, tensor<3x4xf32>) -> (tensor<*xf32>)
|
||||
return %1 : tensor<*xf32>
|
||||
}
|
Loading…
Reference in New Issue